From 0b04e8c7ce68ce062fe4f1a745600e4523ccb17f Mon Sep 17 00:00:00 2001
From: "justin.hsu"
Date: Sat, 9 Nov 2024 04:38:54 +0800
Subject: [PATCH] update flask route

---
Review notes (free text between the "---" line and the first file header is
commentary, not part of the change):

* data/get_best_alpha.py posts every question to /api/chat once per alpha in
  0.0, 0.1, ..., 1.0, saves the answers per alpha, scores them against the
  ground truths and prints the best alpha.
* src/flask_app.py reads the optional "alpha" payload field (default 0.5) and
  forwards it to search_do, which passes it on to hybrid_search.
* search_do keeps 0.5 as the default alpha, so three-argument callers still
  work.
* NOTE(review): indentation and blank lines of unchanged context lines were
  reconstructed from a whitespace-collapsed copy. If a hunk is rejected, retry
  with "git apply --ignore-whitespace" and confirm against the target tree.

 data/get_best_alpha.py   | 82 ++++++++++++++++++++++++++++++++++++++++
 src/flask_app.py         |  4 +++-
 src/utils/weaviate_op.py | 13 ++++---
 3 files changed, 92 insertions(+), 7 deletions(-)
 create mode 100644 data/get_best_alpha.py

diff --git a/data/get_best_alpha.py b/data/get_best_alpha.py
new file mode 100644
index 0000000..e6b7527
--- /dev/null
+++ b/data/get_best_alpha.py
@@ -0,0 +1,82 @@
+"""Grid-search the alpha sent to /api/chat and report the most accurate value."""
+
+import requests
+import json
+from collections import defaultdict
+
+# Load questions from the JSON file
+with open('data/questions_example.json', 'r', encoding='utf-8') as file:
+    questions = json.load(file)['questions']
+
+# Load ground truth data
+with open('data/ground_truths_example.json', 'r', encoding='utf-8') as f:
+    ground_truths = json.load(f)["ground_truths"]
+
+# Dictionary to hold the best alpha and accuracy
+best_alpha = 0.0
+best_accuracy = 0
+
+# Loop through alpha values from 0.0 to 1.0
+for alpha in [round(x * 0.1, 1) for x in range(11)]:
+    output_data = {"answers": []}  # Reset output format with "answers" array
+
+    url = "http://127.0.0.1:5000/api/chat"
+
+    # Send each question to the API with the current alpha
+    for question in questions:
+        # Add the alpha key to the question payload
+        question_with_alpha = {**question, "alpha": alpha}
+
+        # Send POST request; the timeout keeps a hung server from stalling the sweep
+        response = requests.post(url, json=question_with_alpha, timeout=30)
+
+        if response.status_code == 200:
+            response_json = response.json()
+            qid = question.get("qid")
+            retrieve = response_json.get("retrieve")
+
+            # Append formatted result to the answers array
+            output_data["answers"].append({
+                "qid": qid,
+                "retrieve": retrieve
+            })
+        else:
+            print(f"請求失敗,狀態碼: {response.status_code},Alpha 值: {alpha}")
+
+    # Save predictions for the current alpha
+    pred_file = f'data/pred_retrieve_alpha_{alpha}.json'
+    with open(pred_file, 'w', encoding='utf-8') as output_file:
+        json.dump(output_data, output_file, ensure_ascii=False, indent=4)
+
+    # Load predictions for comparison
+    pred_dict = {item["qid"]: item["retrieve"] for item in output_data["answers"]}
+
+    # Initialize counters and data structures for accuracy calculation
+    correct_count = 0
+    category_counts = defaultdict(lambda: {"correct": 0, "total": 0})
+
+    # Compare predictions to ground truth
+    for ground in ground_truths:
+        qid = ground["qid"]
+        category = ground["category"]
+        correct_retrieve = ground["retrieve"]
+        predicted_retrieve = pred_dict.get(qid)
+
+        if predicted_retrieve == correct_retrieve:
+            correct_count += 1
+            category_counts[category]["correct"] += 1
+
+        category_counts[category]["total"] += 1
+
+    # Calculate accuracy for the current alpha (0 when there is no ground truth)
+    accuracy = correct_count / len(ground_truths) if ground_truths else 0.0
+    print("Correct count: ", correct_count)
+    print(f"Alpha: {alpha}, 正確率: {accuracy:.2%}")
+
+    # Track the best alpha and accuracy
+    if accuracy > best_accuracy:
+        best_alpha = alpha
+        best_accuracy = accuracy
+
+# Output the best alpha and accuracy
+print(f"最佳 Alpha 值: {best_alpha}, 準確率: {best_accuracy:.2%}")
diff --git a/src/flask_app.py b/src/flask_app.py
index 72e94c4..aee7fd9 100644
--- a/src/flask_app.py
+++ b/src/flask_app.py
@@ -68,6 +68,8 @@ def post(self):
         source = request.json.get('source')
         question = request.json.get('query')
         category = request.json.get('category')
+        # Optional weight forwarded to search_do; 0.5 when the client omits it.
+        alpha = request.json.get('alpha', 0.5)
 
         # {
         #     "qid": 1,
@@ -82,7 +84,7 @@ def post(self):
             return response
         else:
             try:
-                response = search_do(question, category, source)
+                response = search_do(question, category, source, alpha)
                 response = {
                     'qid': qid,
                     'retrieve': int(response)
diff --git a/src/utils/weaviate_op.py b/src/utils/weaviate_op.py
index 7a83158..471a06e 100644
--- a/src/utils/weaviate_op.py
+++ b/src/utils/weaviate_op.py
@@ -4,6 +4,7 @@ from langchain.embeddings import OpenAIEmbeddings
 
 import utils.config_log as config_log
 
+# import config_log as config_log
 config, logger, CONFIG_PATH = config_log.setup_config_and_logging()
 config.read(CONFIG_PATH)
 
@@ -72,16 +73,16 @@ def hybrid_search(self, query, source, num, alpha):
         return results
 
 
-def search_do(question, category, source):
+def search_do(question, category, source, alpha=0.5):
     if category == "finance":
-        vdb_named = "Finance"
+        vdb_named = "Financedev"
     elif category == "insurance":
-        vdb_named = "Insurance"
+        vdb_named = "Insurancedev"
     else:
-        vdb_named = "Faq"
+        vdb_named = "Faqdev"
 
     searcher = WeaviateSemanticSearch(vdb_named)
-    results = searcher.hybrid_search(question, source, 1, alpha=0.5)
+    results = searcher.hybrid_search(question, source, 1, alpha=alpha)
 
     result_li = []
     for _, result in enumerate(results, 1):
@@ -92,7 +93,7 @@ def search_do(question, category, source):
 
 
 if __name__ == '__main__':
-    vdb = "Faq"
+    vdb = "Insurancedev"
     client = WeaviateSemanticSearch(vdb)
 
     # 統計筆數