Skip to content

Commit

Permalink
feat: add voyage reranker
Browse files Browse the repository at this point in the history
  • Loading branch information
JustinHsu1019 committed Nov 8, 2024
1 parent 2c85338 commit f515b7b
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 2 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ flask_restx==1.3.0
python-dateutil
redis==5.0.8
flask-httpauth==4.8.0
ckip-transformers
voyageai
15 changes: 15 additions & 0 deletions src/automate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import requests
import json
import time # Import time module for timing

# Load questions from the JSON file
with open('data/questions_example.json', 'r', encoding='utf-8') as file:
Expand All @@ -9,7 +10,11 @@

url = "http://127.0.0.1:5000/api/chat"

total_start_time = time.time() # Start timing for the entire process

for question in questions:
question_start_time = time.time() # Start timing for each question

# Send POST request
response = requests.post(url, json=question)

Expand All @@ -28,6 +33,16 @@
print("成功取得 JSON:", response_json)
else:
print("請求失敗,狀態碼:", response.status_code)

# Calculate and print time for each question
question_end_time = time.time()
question_duration = question_end_time - question_start_time
print(f"QID: {qid} - 花費時間: {question_duration:.2f} 秒")

# Calculate and print total time
total_end_time = time.time()
total_duration = total_end_time - total_start_time
print(f"全部題目處理完成,總共花費時間: {total_duration:.2f} 秒")

# Save the output data to a new JSON file
with open('data/pred_retrieve.json', 'w', encoding='utf-8') as output_file:
Expand Down
3 changes: 2 additions & 1 deletion src/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_restx import Api, Resource, fields
from utils.weaviate_op import search_do
# from utils.weaviate_op import search_do
from utils.weaviatexreranker import search_do
from werkzeug.security import check_password_hash, generate_password_hash

config, logger, CONFIG_PATH = config_log.setup_config_and_logging()
Expand Down
97 changes: 97 additions & 0 deletions src/utils/weaviatexreranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
import weaviate
from langchain.embeddings import OpenAIEmbeddings
import utils.config_log as config_log
import voyageai

# 載入設定檔案和日誌設定
config, logger, CONFIG_PATH = config_log.setup_config_and_logging()
config.read(CONFIG_PATH)

# 從 config 中取得 Weaviate URL 和 API 金鑰
wea_url = config.get('Weaviate', 'weaviate_url')
voyage_api_key = config.get('VoyageAI', 'api_key')
PROPERTIES = ['pid', 'content']

# 設定 OpenAI API 金鑰
os.environ['OPENAI_API_KEY'] = config.get('OpenAI', 'api_key')

class WeaviateSemanticSearch:
def __init__(self, classnm):
self.url = wea_url
self.embeddings = OpenAIEmbeddings(chunk_size=1, model='text-embedding-3-large')
self.client = weaviate.Client(url=wea_url)
self.classnm = classnm

def hybrid_search(self, query, source, num, alpha):
query_vector = self.embeddings.embed_query(query)
vector_str = ','.join(map(str, query_vector))

where_conditions = ' '.join([
f'{{path: ["pid"], operator: Equal, valueText: "{pid}"}}' for pid in source
])

gql_query = f"""
{{
Get {{
{self.classnm}(where: {{
operator: Or,
operands: [{where_conditions}]
}}, hybrid: {{
query: "{query}",
vector: [{vector_str}],
alpha: {alpha}
}}, limit: {num}) {{
pid
content
_additional {{
distance
score
}}
}}
}}
}}
"""
search_results = self.client.query.raw(gql_query)

if 'errors' in search_results:
raise Exception(search_results['errors'][0]['message'])

results = search_results['data']['Get'][self.classnm]
return results


def rerank_with_voyage(query, documents, pids, api_key):
vo = voyageai.Client(api_key=api_key)
reranking = vo.rerank(query, documents, model="rerank-2", top_k=1)
top_result = reranking.results[0]

# 根據內容找到相對應的 pid
top_pid = pids[documents.index(top_result.document)]
return {'pid': top_pid, 'relevance_score': top_result.relevance_score}


def search_do(question, category, source, alpha):
if category == "finance":
vdb_named = "Financedev"
elif category == "insurance":
vdb_named = "Insurancedev"
else:
vdb_named = "Faqdev"

searcher = WeaviateSemanticSearch(vdb_named)
# 從 Weaviate 取得前 100 筆結果
top_100_results = searcher.hybrid_search(question, source, 100, alpha=alpha)

# 準備文件和 pid 列表供 rerank 使用
documents = [result['content'] for result in top_100_results]
pids = [result['pid'] for result in top_100_results]

# 使用 VoyageAI 重新排序,並取得排名最高的 pid
top_reranked_result = rerank_with_voyage(question, documents, pids, voyage_api_key)

print("最相關文件的 PID:")
print(f"PID: {top_reranked_result['pid']}")
print(f"相關性分數: {top_reranked_result['relevance_score']}")

return top_reranked_result['pid']

0 comments on commit f515b7b

Please sign in to comment.