From f515b7b4b421eef3a8f0074c70f9b277d36edf2d Mon Sep 17 00:00:00 2001 From: "justin.hsu" Date: Sat, 9 Nov 2024 05:59:12 +0800 Subject: [PATCH] feat: add voyage reranker --- requirements.txt | 2 +- src/automate.py | 15 ++++++ src/flask_app.py | 3 +- src/utils/weaviatexreranker.py | 97 ++++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 src/utils/weaviatexreranker.py diff --git a/requirements.txt b/requirements.txt index 9e124df..743e682 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ flask_restx==1.3.0 python-dateutil redis==5.0.8 flask-httpauth==4.8.0 -ckip-transformers +voyageai diff --git a/src/automate.py b/src/automate.py index 14ae9e2..b2b8a04 100644 --- a/src/automate.py +++ b/src/automate.py @@ -1,5 +1,6 @@ import requests import json +import time # Import time module for timing # Load questions from the JSON file with open('data/questions_example.json', 'r', encoding='utf-8') as file: @@ -9,7 +10,11 @@ url = "http://127.0.0.1:5000/api/chat" +total_start_time = time.time() # Start timing for the entire process + for question in questions: + question_start_time = time.time() # Start timing for each question + # Send POST request response = requests.post(url, json=question) @@ -28,6 +33,16 @@ print("成功取得 JSON:", response_json) else: print("請求失敗,狀態碼:", response.status_code) + + # Calculate and print time for each question + question_end_time = time.time() + question_duration = question_end_time - question_start_time + print(f"QID: {qid} - 花費時間: {question_duration:.2f} 秒") + +# Calculate and print total time +total_end_time = time.time() +total_duration = total_end_time - total_start_time +print(f"全部題目處理完成,總共花費時間: {total_duration:.2f} 秒") # Save the output data to a new JSON file with open('data/pred_retrieve.json', 'w', encoding='utf-8') as output_file: diff --git a/src/flask_app.py b/src/flask_app.py index aee7fd9..2e4b0f5 100644 --- a/src/flask_app.py +++ b/src/flask_app.py @@ -5,7 +5,8 @@ from flask_limiter import Limiter from flask_limiter.util import get_remote_address from flask_restx import Api, Resource, fields -from utils.weaviate_op import search_do +# from utils.weaviate_op import search_do +from utils.weaviatexreranker import search_do from werkzeug.security import check_password_hash, generate_password_hash config, logger, CONFIG_PATH = config_log.setup_config_and_logging() diff --git a/src/utils/weaviatexreranker.py b/src/utils/weaviatexreranker.py new file mode 100644 index 0000000..e314954 --- /dev/null +++ b/src/utils/weaviatexreranker.py @@ -0,0 +1,97 @@ +import os +import weaviate +from langchain.embeddings import OpenAIEmbeddings +import utils.config_log as config_log +import voyageai + +# 載入設定檔案和日誌設定 +config, logger, CONFIG_PATH = config_log.setup_config_and_logging() +config.read(CONFIG_PATH) + +# 從 config 中取得 Weaviate URL 和 API 金鑰 +wea_url = config.get('Weaviate', 'weaviate_url') +voyage_api_key = config.get('VoyageAI', 'api_key') +PROPERTIES = ['pid', 'content'] + +# 設定 OpenAI API 金鑰 +os.environ['OPENAI_API_KEY'] = config.get('OpenAI', 'api_key') + +class WeaviateSemanticSearch: + def __init__(self, classnm): + self.url = wea_url + self.embeddings = OpenAIEmbeddings(chunk_size=1, model='text-embedding-3-large') + self.client = weaviate.Client(url=wea_url) + self.classnm = classnm + + def hybrid_search(self, query, source, num, alpha): + query_vector = self.embeddings.embed_query(query) + vector_str = ','.join(map(str, query_vector)) + + where_conditions = ' '.join([ + f'{{path: ["pid"], operator: Equal, valueText: "{pid}"}}' for pid in source + ]) + + gql_query = f""" + {{ + Get {{ + {self.classnm}(where: {{ + operator: Or, + operands: [{where_conditions}] + }}, hybrid: {{ + query: "{query}", + vector: [{vector_str}], + alpha: {alpha} + }}, limit: {num}) {{ + pid + content + _additional {{ + distance + score + }} + }} + }} + }} + """ + search_results = self.client.query.raw(gql_query) + + if 'errors' in search_results: + raise Exception(search_results['errors'][0]['message']) + + results = search_results['data']['Get'][self.classnm] + return results + + +def rerank_with_voyage(query, documents, pids, api_key): + vo = voyageai.Client(api_key=api_key) + reranking = vo.rerank(query, documents, model="rerank-2", top_k=1) + top_result = reranking.results[0] + + # 根據內容找到相對應的 pid + top_pid = pids[documents.index(top_result.document)] + return {'pid': top_pid, 'relevance_score': top_result.relevance_score} + + +def search_do(question, category, source, alpha): + if category == "finance": + vdb_named = "Financedev" + elif category == "insurance": + vdb_named = "Insurancedev" + else: + vdb_named = "Faqdev" + + searcher = WeaviateSemanticSearch(vdb_named) + # 從 Weaviate 取得前 100 筆結果 + top_100_results = searcher.hybrid_search(question, source, 100, alpha=alpha) + + # 準備文件和 pid 列表供 rerank 使用 + documents = [result['content'] for result in top_100_results] + pids = [result['pid'] for result in top_100_results] + + # 使用 VoyageAI 重新排序,並取得排名最高的 pid + top_reranked_result = rerank_with_voyage(question, documents, pids, voyage_api_key) + + print("最相關文件的 PID:") + print(f"PID: {top_reranked_result['pid']}") + print(f"相關性分數: {top_reranked_result['relevance_score']}") + + return top_reranked_result['pid']