From 2cb57696df5af8adf4f8fd522e95868a7da4efe5 Mon Sep 17 00:00:00 2001 From: Maryam Teimouri Date: Tue, 3 Sep 2024 10:50:12 +0300 Subject: [PATCH] Update/faiss (#13) * Chroma is replaced with FAISS as the vector database. Files changed for the replacement: get_rag_chain.py, populate_database.py. Other files changed for naming and refactoring. --- local_app.py | 3 +- python_script/config.json | 10 +-- python_script/get_rag_chain.py | 34 ++++--- python_script/parameters.py | 15 ++-- python_script/populate_database.py | 140 +++++++++++++++-------------- 5 files changed, 108 insertions(+), 94 deletions(-) diff --git a/local_app.py b/local_app.py index 6bb5b0a..3305dc8 100644 --- a/local_app.py +++ b/local_app.py @@ -9,7 +9,7 @@ from parameters import load_config global DATA_PATH load_config('test') -from parameters import CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE, DATA_PATH, REPHRASING_PROMPT, STANDALONE_PROMPT, ROUTER_DECISION_PROMPT +from parameters import DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE, DATA_PATH, REPHRASING_PROMPT, STANDALONE_PROMPT, ROUTER_DECISION_PROMPT from get_llm_function import get_llm_function from get_rag_chain import get_rag_chain from ConversationalRagChain import ConversationalRagChain @@ -82,7 +82,6 @@ def get_Chat_response(query): "chat_history": [] } res = rag_conv._call(inputs) - print(res['metadatas']) output = jsonify({ 'response': res['result'], 'context': res['context'], diff --git a/python_script/config.json b/python_script/config.json index abc6539..211a99a 100644 --- a/python_script/config.json +++ b/python_script/config.json @@ -1,31 +1,31 @@ { "default": { "data_path": "data/documents/default", - "chroma_root_path": "data/chroma/default", + "database_root_path": "data/database/default", "embedding_model": "voyage-law-2", "llm_model": "gpt-3.5-turbo" }, "seus": { "data_path": "data/documents/ship_data", - "chroma_root_path": "data/chroma/ship_chroma", + "database_root_path": "data/database/ship", 
"embedding_model": "openai", "llm_model": "gpt-3.5-turbo" }, "arch-en": { "data_path": "data/documents/arch_data-en", - "chroma_root_path": "data/chroma/arch_data-en_chroma", + "database_root_path": "data/database/arch_data-en", "embedding_model": "openai", "llm_model": "gpt-3.5-turbo" }, "arch-ru": { "data_path": "data/documents/arch_data-ru", - "chroma_root_path": "data/chroma/arch_data-ru_chroma", + "database_root_path": "data/database/arch_data-ru", "embedding_model": "openai", "llm_model": "gpt-3.5-turbo" }, "test": { "data_path": "data/documents/test_data", - "chroma_root_path": "data/chroma/test_chroma", + "database_root_path": "data/database/test", "embedding_model": "openai", "llm_model": "gpt-3.5-turbo" } diff --git a/python_script/get_rag_chain.py b/python_script/get_rag_chain.py index b14aaba..104b348 100644 --- a/python_script/get_rag_chain.py +++ b/python_script/get_rag_chain.py @@ -1,14 +1,14 @@ -from parameters import CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL +from parameters import DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL from get_embedding_function import get_embedding_function from get_llm_function import get_llm_function -from populate_database import find_chroma_path +from populate_database import find_database_path -from langchain.vectorstores.chroma import Chroma from langchain.chains import create_history_aware_retriever from langchain.chains import create_retrieval_chain from langchain.chains.combine_documents import create_stuff_documents_chain from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_community.vectorstores import FAISS def get_rag_chain(params = None): """ @@ -21,7 +21,7 @@ def get_rag_chain(params = None): Parameters: params (dict, optional): A dictionary of configuration parameters. - - chroma_root_path (str): The root path for Chroma data storage. + - database_root_path (str): The root path for data storage. - embedding_model (str): The model name for the embedding function. 
- llm_model (str): The model name for the language model. - search_type (str): The type of search to perform. Options are: @@ -41,7 +41,7 @@ def get_rag_chain(params = None): """ default_params = { - "chroma_root_path": CHROMA_ROOT_PATH, + "database_root_path": DATABASE_ROOT_PATH, "embedding_model": EMBEDDING_MODEL, "llm_model": LLM_MODEL, "search_type": "similarity", @@ -59,25 +59,34 @@ def get_rag_chain(params = None): params = {**default_params, **params} try: - required_keys = ["chroma_root_path", "embedding_model", "llm_model"] + required_keys = ["database_root_path", "embedding_model", "llm_model"] for key in required_keys: if key not in params: raise NameError(f"Required setting '{key}' not defined.") embedding_model = get_embedding_function(model_name=params["embedding_model"]) llm = get_llm_function(model_name=params["llm_model"]) - db = Chroma(persist_directory=find_chroma_path(model_name=params["embedding_model"], base_path=params["chroma_root_path"]), embedding_function=embedding_model) - + + # Load the FAISS index from disk + vector_store = FAISS.load_local(find_database_path(EMBEDDING_MODEL,DATABASE_ROOT_PATH) + , embedding_model, allow_dangerous_deserialization=True) + search_type = params["search_type"] if search_type == "similarity": - retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["similarity_doc_nb"]}) + retriever = vector_store.as_retriever(search_type=search_type, + search_kwargs={"k": params["similarity_doc_nb"]}) elif search_type == "similarity_score_threshold": - retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["max_chunk_return"],"score_threshold": params["score_threshold"]}) + retriever = vector_store.as_retriever(search_type=search_type, + search_kwargs={"k": params["max_chunk_return"], + "score_threshold": params["score_threshold"]}) elif search_type == "mmr": - retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["mmr_doc_nb"], "fetch_k": 
params["considered_chunk"], "lambda_mult": params["lambda_mult"]}) + retriever = vector_store.as_retriever(search_type=search_type, + search_kwargs={"k": params["mmr_doc_nb"], + "fetch_k": params["considered_chunk"], + "lambda_mult": params["lambda_mult"]}) else: raise ValueError("Invalid 'search_type' setting") - + except NameError as e: variable_name = str(e).split("'")[1] raise NameError(f"{variable_name} isn't defined") @@ -118,5 +127,4 @@ def get_rag_chain(params = None): question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) rag_chain = create_retrieval_chain(retriever, question_answer_chain) - return rag_chain \ No newline at end of file diff --git a/python_script/parameters.py b/python_script/parameters.py index a89b9cf..3ef30e3 100644 --- a/python_script/parameters.py +++ b/python_script/parameters.py @@ -4,7 +4,7 @@ from langchain.prompts import PromptTemplate DATA_PATH = None -CHROMA_ROOT_PATH = None +DATABASE_ROOT_PATH = None EMBEDDING_MODEL = None LLM_MODEL = None PROMPT_TEMPLATE = None @@ -19,7 +19,7 @@ def load_api_keys(): load_dotenv() os.environ["HF_API_TOKEN"] = os.getenv("HF_API_TOKEN") os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") - os.environ["VOYAGE_API_KEY"] = os.getenv("VOYAGE_API_KEY") + #os.environ["VOYAGE_API_KEY"] = os.getenv("VOYAGE_API_KEY") def load_config(config_name = 'default', show_config = False): @@ -31,7 +31,7 @@ def load_config(config_name = 'default', show_config = False): { "config_name": { "data_path": "", # Path to the data folder - "chroma_root_path": "", # Path to the folder where the Chroma database will be stored + "database_root_path": "", # Path to the folder where the database will be stored "embedding_model": "", # Model to use for embeddings (e.g., 'sentence-transformers/all-mpnet-base-v2', 'openai', 'voyage-law-2') "llm_model": "", # Model to use for the language model (e.g., 'gpt-3.5-turbo', 'mistralai/Mistral-7B-Instruct-v0.1', 'nvidia/Llama3-ChatQA-1.5-8B') } @@ -48,7 +48,7 @@ def 
load_config(config_name = 'default', show_config = False): - "mistralai/Mixtral-8x7B-Instruct-v0.1" - "nvidia/Llama3-ChatQA-1.5-8B" """ - global DATA_PATH, CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL + global DATA_PATH, DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL try: with open('config.json', 'r') as file: config = json.load(file) @@ -60,11 +60,10 @@ def load_config(config_name = 'default', show_config = False): raise FileNotFoundError("The configuration file cannot be found in the specified paths.") except json.JSONDecodeError: raise ValueError("The configuration file is present but contains a JSON format error.") - selected_config = config[config_name] DATA_PATH = selected_config['data_path'] - CHROMA_ROOT_PATH = selected_config['chroma_root_path'] + DATABASE_ROOT_PATH = selected_config['database_root_path'] EMBEDDING_MODEL = selected_config['embedding_model'] LLM_MODEL = selected_config['llm_model'] @@ -79,11 +78,11 @@ def print_config(): Print the current configuration settings. This function prints the values of the global configuration parameters. 
""" - global DATA_PATH, CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL + global DATA_PATH, DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL print("\nCurrent Configuration Settings:\n") print(f"Data Path: {DATA_PATH}") - print(f"Chroma Root Path: {CHROMA_ROOT_PATH}") + print(f"Database Root Path: {DATABASE_ROOT_PATH}") print(f"Embedding Model: {EMBEDDING_MODEL}") print(f"Language Model: {LLM_MODEL}\n") diff --git a/python_script/populate_database.py b/python_script/populate_database.py index b3560e1..589613a 100644 --- a/python_script/populate_database.py +++ b/python_script/populate_database.py @@ -2,22 +2,31 @@ import os import shutil -import fitz from tqdm import tqdm +import faiss +from typing import List +from pathlib import Path +import logging + +from get_embedding_function import get_embedding_function from langchain.schema.document import Document from langchain.document_loaders.pdf import PyPDFDirectoryLoader -from langchain.document_loaders.pdf import PyPDFLoader from langchain.document_loaders.pdf import PDFPlumberLoader -from typing import List, Union -from pathlib import Path -import logging +from llama_index.core import SimpleDirectoryReader +from langchain_text_splitters import RecursiveCharacterTextSplitter +from langchain_community.vectorstores import FAISS +from langchain_community.docstore.in_memory import InMemoryDocstore + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) def main(): - parser = argparse.ArgumentParser(description="This script manages the database for the application. You can use it to clear the database, reset it, or populate it with a specific configuration that can be specified in the config.json file.") + parser = argparse.ArgumentParser(description="""This script manages the database for the application. 
+ You can use it to clear the database, reset it, or populate + it with a specific configuration that can be specified in the + config.json file.""") parser.add_argument("--config", type=str, help="Enter the name of the config you want to populate your database.") parser.add_argument("--reset", action="store_true", help="Reset the database.") parser.add_argument("--clear", action="store_true", help="Clear the database.") @@ -28,7 +37,7 @@ def main(): print("Clearing Database...") if args.config: load_config(args.config) - subfolder_name = "chroma_{}".format(EMBEDDING_MODEL) + subfolder_name = "database_{}".format(EMBEDDING_MODEL) clear_database(subfolder_name) else: clear_database() @@ -37,7 +46,7 @@ def main(): if args.config: load_config(args.config) if args.reset: - subfolder_name = "chroma_{}".format(EMBEDDING_MODEL) + subfolder_name = "database_{}".format(EMBEDDING_MODEL) print("Reseting Database...") try: clear_database(subfolder_name) @@ -46,16 +55,54 @@ def main(): documents = load_documents() chunks = split_documents(documents) - add_to_chroma(chunks) + add_to_database(chunks) + +def add_to_database(chunks: list[Document]): + + # Assume all valid embeddings have the same dimension + index = faiss.IndexFlatL2(len(get_embedding_function(EMBEDDING_MODEL).embed_query("hello world"))) + + vector_store = FAISS( + embedding_function=get_embedding_function(EMBEDDING_MODEL), + index=index, + docstore= InMemoryDocstore(), + index_to_docstore_id={} + ) + existing_ids = [] + + index_file = find_database_path(EMBEDDING_MODEL,DATABASE_ROOT_PATH) + "index.faiss" + if os.path.exists(index_file): + vector_store = FAISS.load_local(find_database_path(EMBEDDING_MODEL,DATABASE_ROOT_PATH), + get_embedding_function(EMBEDDING_MODEL), + allow_dangerous_deserialization=True) + # Add or Update the documents. 
+ existing_ids = vector_store.index_to_docstore_id.values() + print("existing_ids", vector_store.index_to_docstore_id.values()) + + chunks_with_ids = calculate_chunk_ids(chunks) + + # Only add documents that don't exist in the DB. + new_chunks = [chunk for chunk in chunks_with_ids if chunk.metadata["id"] not in existing_ids] + batch_size = 1000 + + if new_chunks: + with tqdm(total=len(new_chunks), desc="Adding chunks") as pbar: + for i in range(0,len(new_chunks), batch_size): + batch = new_chunks[i:i + batch_size] + new_chunk_ids = [chunk.metadata["id"] for chunk in batch] + vector_store.add_documents(batch, ids=new_chunk_ids) + pbar.update(len(batch)) + + vector_store.save_local(find_database_path(EMBEDDING_MODEL,DATABASE_ROOT_PATH)) def load_config(config_name): """ Load and print the parameters entered for config_name into the config.json file. """ - global DATA_PATH, CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE + global DATA_PATH, DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE from parameters import load_config as ld ld(config_name, show_config=True) - from parameters import DATA_PATH, CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE + from parameters import DATA_PATH, DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE def load_documents(): """ @@ -66,7 +113,6 @@ def load_documents(): #TODO Make something better maybe #TODO Implement other document type. 
llamaindex tool can do it """ - from llama_index.core import SimpleDirectoryReader langchain_documents = [] llama_documents = [] @@ -82,11 +128,12 @@ def load_documents(): langchain_document_loader = ProgressPyPDFDirectoryLoader(DATA_PATH) for doc in tqdm(langchain_document_loader.load(), desc="PDFs loaded"): doc.metadata.pop('file_path',None) - print(doc.metadata) langchain_documents.append(doc) documents = langchain_documents + convert_llamaindexdoc_to_langchaindoc(llama_documents) - print(f"Loaded {len(langchain_documents)} page{'s' if len(langchain_documents) > 1 else ''} from PDF document{'s' if len(langchain_documents) > 1 else ''}, {len(llama_documents)} item{'s' if len(llama_documents) > 1 else ''} from TXT/DOCX document{'s' if len(llama_documents) > 1 else ''}.\nTotal items: {len(documents)}.\n") + print(f"Loaded {len(langchain_documents)} page{'s' if len(langchain_documents) > 1 else ''} from PDF document{ \ + 's' if len(langchain_documents) > 1 else ''}, {len(llama_documents)} item{'s' \ + if len(llama_documents) > 1 else ''} from TXT/DOCX document{'s' if len(llama_documents) > 1 else ''}.\nTotal items: {len(documents)}.\n") return documents def convert_llamaindexdoc_to_langchaindoc(documents: list[Document]): @@ -102,7 +149,6 @@ def split_documents(documents: list[Document]): """ Split documents into smaller chunks """ - from langchain_text_splitters import RecursiveCharacterTextSplitter text_splitter = RecursiveCharacterTextSplitter( chunk_size=800, chunk_overlap=80, @@ -111,44 +157,6 @@ def split_documents(documents: list[Document]): ) return text_splitter.split_documents(documents) - -def add_to_chroma(chunks: list[Document]): - """ - Load the chroma database - Check if there are new documents in the documents database - Add them to the chroma database - """ - from langchain.vectorstores.chroma import Chroma - from get_embedding_function import get_embedding_function - # Load the existing database. 
- db = Chroma( - persist_directory=find_chroma_path(EMBEDDING_MODEL,CHROMA_ROOT_PATH), embedding_function=get_embedding_function(EMBEDDING_MODEL) - ) - - chunks_with_ids = calculate_chunk_ids(chunks) - - # Add or Update the documents. - existing_items = db.get(include=[]) - existing_ids = set(existing_items["ids"]) - print(f"Number of existing chunks in DB: {len(existing_ids)}") - - # Only add documents that don't exist in the DB. - new_chunks = [chunk for chunk in chunks_with_ids if chunk.metadata["id"] not in existing_ids] - batch_size = 1000 - - if new_chunks: - with tqdm(total=len(new_chunks), desc="Adding chunks") as pbar: - for i in range(0,len(new_chunks), batch_size): - batch = new_chunks[i:i + batch_size] - new_chunk_ids = [chunk.metadata["id"] for chunk in batch] - db.add_documents(batch, ids=new_chunk_ids) - db.persist() - pbar.update(len(batch)) - - else: - print("Done. No new chunks to add") - - def calculate_chunk_ids(chunks): """ Add metadata id to the chunk in the following format: @@ -172,32 +180,32 @@ def calculate_chunk_ids(chunks): chunk.metadata["id"] = chunk_id return chunks -def find_chroma_path(model_name, base_path): +def find_database_path(model_name, base_path): """ - Find the path to the chroma database corresponding to the Embedding model - Create the subfolder in the chroma root folder if not exists + Find the path to the database corresponding to the Embedding model + Create the subfolder in the database root folder if not exists """ if not model_name: raise ValueError("Model name can't be empty") if not base_path: try: - base_path = CHROMA_ROOT_PATH + base_path = DATABASE_ROOT_PATH except: - raise ValueError("The Chroma database root file is not populated") + raise ValueError("The database database root file is not populated") - model_path = os.path.join(base_path, f"chroma_{model_name}") + model_path = os.path.join(base_path, f"database_{model_name}") if not os.path.exists(model_path): os.makedirs(model_path) return model_path -def 
clear_database(chroma_subfolder_name = None): +def clear_database(database_subfolder_name = None): """ - Clear the folder if chroma_subfolder_name is set and exists + Clear the folder if database_subfolder_name is set and exists Clear the whole database if not set but ask the user before """ - if chroma_subfolder_name: - full_path = os.path.join(CHROMA_ROOT_PATH, chroma_subfolder_name) + if database_subfolder_name: + full_path = os.path.join(DATABASE_ROOT_PATH, database_subfolder_name) if os.path.exists(full_path): shutil.rmtree(full_path) print(f"The database in {full_path} has been successfully deleted.") @@ -205,15 +213,15 @@ def clear_database(chroma_subfolder_name = None): raise FolderNotFoundError(f"Folder {full_path} doesn't exist") else: print("\nExisting databases :\n\n") - subfolders = [f for f in os.listdir(CHROMA_ROOT_PATH) if os.path.isdir(os.path.join(CHROMA_ROOT_PATH, f))] + subfolders = [f for f in os.listdir(DATABASE_ROOT_PATH) if os.path.isdir(os.path.join(DATABASE_ROOT_PATH, f))] if not subfolders: - print(f"no subfolder found in {CHROMA_ROOT_PATH}\n\n") + print(f"no subfolder found in {DATABASE_ROOT_PATH}\n\n") for subfolder in subfolders: print(f"- {subfolder}") confirmation = input("Do you want to delete all the databases ? (yes/no) : ") if confirmation.lower() == 'yes': - shutil.rmtree(CHROMA_ROOT_PATH) + shutil.rmtree(DATABASE_ROOT_PATH) print("All databases cleared") else: print("Deletion cancelled.")