Skip to content

Commit

Permalink
Update/faiss (#13)
Browse files Browse the repository at this point in the history
* chroma is replaced with faiss for the database
changed files for replacement: get_rag_chain.py, populate_database.py
other files changed for naming and refactoring
  • Loading branch information
maryamteimouri authored Sep 3, 2024
1 parent 4b7c072 commit 2cb5769
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 94 deletions.
3 changes: 1 addition & 2 deletions local_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from parameters import load_config
global DATA_PATH
load_config('test')
from parameters import CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE, DATA_PATH, REPHRASING_PROMPT, STANDALONE_PROMPT, ROUTER_DECISION_PROMPT
from parameters import DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL, PROMPT_TEMPLATE, DATA_PATH, REPHRASING_PROMPT, STANDALONE_PROMPT, ROUTER_DECISION_PROMPT
from get_llm_function import get_llm_function
from get_rag_chain import get_rag_chain
from ConversationalRagChain import ConversationalRagChain
Expand Down Expand Up @@ -82,7 +82,6 @@ def get_Chat_response(query):
"chat_history": []
}
res = rag_conv._call(inputs)
print(res['metadatas'])
output = jsonify({
'response': res['result'],
'context': res['context'],
Expand Down
10 changes: 5 additions & 5 deletions python_script/config.json
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
{
"default": {
"data_path": "data/documents/default",
"chroma_root_path": "data/chroma/default",
"database_root_path": "data/database/default",
"embedding_model": "voyage-law-2",
"llm_model": "gpt-3.5-turbo"
},
"seus": {
"data_path": "data/documents/ship_data",
"chroma_root_path": "data/chroma/ship_chroma",
"database_root_path": "data/database/ship",
"embedding_model": "openai",
"llm_model": "gpt-3.5-turbo"
},
"arch-en": {
"data_path": "data/documents/arch_data-en",
"chroma_root_path": "data/chroma/arch_data-en_chroma",
"database_root_path": "data/database/arch_data-en",
"embedding_model": "openai",
"llm_model": "gpt-3.5-turbo"
},
"arch-ru": {
"data_path": "data/documents/arch_data-ru",
"chroma_root_path": "data/chroma/arch_data-ru_chroma",
"database_root_path": "data/database/arch_data-ru",
"embedding_model": "openai",
"llm_model": "gpt-3.5-turbo"
},
"test": {
"data_path": "data/documents/test_data",
"chroma_root_path": "data/chroma/test_chroma",
"database_root_path": "data/database/test",
"embedding_model": "openai",
"llm_model": "gpt-3.5-turbo"
}
Expand Down
34 changes: 21 additions & 13 deletions python_script/get_rag_chain.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from parameters import CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL
from parameters import DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL

from get_embedding_function import get_embedding_function
from get_llm_function import get_llm_function
from populate_database import find_chroma_path
from populate_database import find_database_path

from langchain.vectorstores.chroma import Chroma
from langchain.chains import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.vectorstores import FAISS

def get_rag_chain(params = None):
"""
Expand All @@ -21,7 +21,7 @@ def get_rag_chain(params = None):
Parameters:
params (dict, optional): A dictionary of configuration parameters.
- chroma_root_path (str): The root path for Chroma data storage.
- database_root_path (str): The root path for data storage.
- embedding_model (str): The model name for the embedding function.
- llm_model (str): The model name for the language model.
- search_type (str): The type of search to perform. Options are:
Expand All @@ -41,7 +41,7 @@ def get_rag_chain(params = None):
"""

default_params = {
"chroma_root_path": CHROMA_ROOT_PATH,
"database_root_path": DATABASE_ROOT_PATH,
"embedding_model": EMBEDDING_MODEL,
"llm_model": LLM_MODEL,
"search_type": "similarity",
Expand All @@ -59,25 +59,34 @@ def get_rag_chain(params = None):
params = {**default_params, **params}

try:
required_keys = ["chroma_root_path", "embedding_model", "llm_model"]
required_keys = ["database_root_path", "embedding_model", "llm_model"]
for key in required_keys:
if key not in params:
raise NameError(f"Required setting '{key}' not defined.")

embedding_model = get_embedding_function(model_name=params["embedding_model"])
llm = get_llm_function(model_name=params["llm_model"])
db = Chroma(persist_directory=find_chroma_path(model_name=params["embedding_model"], base_path=params["chroma_root_path"]), embedding_function=embedding_model)


# Load the FAISS index from disk
vector_store = FAISS.load_local(find_database_path(EMBEDDING_MODEL,DATABASE_ROOT_PATH)
, embedding_model, allow_dangerous_deserialization=True)

search_type = params["search_type"]
if search_type == "similarity":
retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["similarity_doc_nb"]})
retriever = vector_store.as_retriever(search_type=search_type,
search_kwargs={"k": params["similarity_doc_nb"]})
elif search_type == "similarity_score_threshold":
retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["max_chunk_return"],"score_threshold": params["score_threshold"]})
retriever = vector_store.as_retriever(search_type=search_type,
search_kwargs={"k": params["max_chunk_return"],
"score_threshold": params["score_threshold"]})
elif search_type == "mmr":
retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": params["mmr_doc_nb"], "fetch_k": params["considered_chunk"], "lambda_mult": params["lambda_mult"]})
retriever = vector_store.as_retriever(search_type=search_type,
search_kwargs={"k": params["mmr_doc_nb"],
"fetch_k": params["considered_chunk"],
"lambda_mult": params["lambda_mult"]})
else:
raise ValueError("Invalid 'search_type' setting")

except NameError as e:
variable_name = str(e).split("'")[1]
raise NameError(f"{variable_name} isn't defined")
Expand Down Expand Up @@ -118,5 +127,4 @@ def get_rag_chain(params = None):
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(retriever, question_answer_chain)

return rag_chain
15 changes: 7 additions & 8 deletions python_script/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from langchain.prompts import PromptTemplate

DATA_PATH = None
CHROMA_ROOT_PATH = None
DATABASE_ROOT_PATH = None
EMBEDDING_MODEL = None
LLM_MODEL = None
PROMPT_TEMPLATE = None
Expand All @@ -19,7 +19,7 @@ def load_api_keys():
load_dotenv()
os.environ["HF_API_TOKEN"] = os.getenv("HF_API_TOKEN")
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
os.environ["VOYAGE_API_KEY"] = os.getenv("VOYAGE_API_KEY")
#os.environ["VOYAGE_API_KEY"] = os.getenv("VOYAGE_API_KEY")


def load_config(config_name = 'default', show_config = False):
Expand All @@ -31,7 +31,7 @@ def load_config(config_name = 'default', show_config = False):
{
"config_name": {
"data_path": "", # Path to the data folder
"chroma_root_path": "", # Path to the folder where the Chroma database will be stored
"database_root_path": "", # Path to the folder where the database will be stored
"embedding_model": "", # Model to use for embeddings (e.g., 'sentence-transformers/all-mpnet-base-v2', 'openai', 'voyage-law-2')
"llm_model": "", # Model to use for the language model (e.g., 'gpt-3.5-turbo', 'mistralai/Mistral-7B-Instruct-v0.1', 'nvidia/Llama3-ChatQA-1.5-8B')
}
Expand All @@ -48,7 +48,7 @@ def load_config(config_name = 'default', show_config = False):
- "mistralai/Mixtral-8x7B-Instruct-v0.1"
- "nvidia/Llama3-ChatQA-1.5-8B"
"""
global DATA_PATH, CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL
global DATA_PATH, DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL
try:
with open('config.json', 'r') as file:
config = json.load(file)
Expand All @@ -60,11 +60,10 @@ def load_config(config_name = 'default', show_config = False):
raise FileNotFoundError("The configuration file cannot be found in the specified paths.")
except json.JSONDecodeError:
raise ValueError("The configuration file is present but contains a JSON format error.")

selected_config = config[config_name]

DATA_PATH = selected_config['data_path']
CHROMA_ROOT_PATH = selected_config['chroma_root_path']
DATABASE_ROOT_PATH = selected_config['database_root_path']
EMBEDDING_MODEL = selected_config['embedding_model']
LLM_MODEL = selected_config['llm_model']

Expand All @@ -79,11 +78,11 @@ def print_config():
Print the current configuration settings.
This function prints the values of the global configuration parameters.
"""
global DATA_PATH, CHROMA_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL
global DATA_PATH, DATABASE_ROOT_PATH, EMBEDDING_MODEL, LLM_MODEL

print("\nCurrent Configuration Settings:\n")
print(f"Data Path: {DATA_PATH}")
print(f"Chroma Root Path: {CHROMA_ROOT_PATH}")
print(f"Database Root Path: {DATABASE_ROOT_PATH}")
print(f"Embedding Model: {EMBEDDING_MODEL}")
print(f"Language Model: {LLM_MODEL}\n")

Expand Down
Loading

0 comments on commit 2cb5769

Please sign in to comment.