Skip to content

Commit

Permalink
update flask rou
Browse files Browse the repository at this point in the history
  • Loading branch information
JustinHsu1019 committed Nov 8, 2024
1 parent 88aa643 commit 0b04e8c
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 7 deletions.
80 changes: 80 additions & 0 deletions data/get_best_alpha.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import requests
import json
from collections import defaultdict

# Load questions from the JSON file
with open('data/questions_example.json', 'r', encoding='utf-8') as file:
questions = json.load(file)['questions']

# Load ground truth data
with open('data/ground_truths_example.json', 'r', encoding='utf-8') as f:
ground_truths = json.load(f)["ground_truths"]

# Dictionary to hold the best alpha and accuracy
best_alpha = 0.0
best_accuracy = 0

# Loop through alpha values from 0.0 to 1.0
for alpha in [round(x * 0.1, 1) for x in range(11)]:
output_data = {"answers": []} # Reset output format with "answers" array

url = "http://127.0.0.1:5000/api/chat"

# Send each question to the API with the current alpha
for question in questions:
# Add the alpha key to the question payload
question_with_alpha = {**question, "alpha": alpha}

# Send POST request
response = requests.post(url, json=question_with_alpha)

if response.status_code == 200:
response_json = response.json()
qid = question.get("qid")
retrieve = response_json.get("retrieve")

# Append formatted result to the answers array
output_data["answers"].append({
"qid": qid,
"retrieve": retrieve
})
else:
print(f"請求失敗,狀態碼: {response.status_code},Alpha 值: {alpha}")

# Save predictions for the current alpha
pred_file = f'data/pred_retrieve_alpha_{alpha}.json'
with open(pred_file, 'w', encoding='utf-8') as output_file:
json.dump(output_data, output_file, ensure_ascii=False, indent=4)

# Load predictions for comparison
pred_dict = {item["qid"]: item["retrieve"] for item in output_data["answers"]}

# Initialize counters and data structures for accuracy calculation
correct_count = 0
category_counts = defaultdict(lambda: {"correct": 0, "total": 0})

# Compare predictions to ground truth
for ground in ground_truths:
qid = ground["qid"]
category = ground["category"]
correct_retrieve = ground["retrieve"]
predicted_retrieve = pred_dict.get(qid)

if predicted_retrieve == correct_retrieve:
correct_count += 1
category_counts[category]["correct"] += 1

category_counts[category]["total"] += 1

# Calculate accuracy for the current alpha
accuracy = correct_count / len(ground_truths)
print("Corrrect count: ", correct_count)
print(f"Alpha: {alpha}, 正確率: {accuracy:.2%}")

# Track the best alpha and accuracy
if accuracy > best_accuracy:
best_alpha = alpha
best_accuracy = accuracy

# Output the best alpha and accuracy
print(f"最佳 Alpha 值: {best_alpha}, 準確率: {best_accuracy:.2%}")
5 changes: 4 additions & 1 deletion src/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,24 @@ def post(self):
source = request.json.get('source')
question = request.json.get('query')
category = request.json.get('category')
# alpha = request.json.get('alpha')

# {
# "qid": 1,
# "source": [442, 115, 440, 196, 431, 392, 14, 51],
# "query": "匯款銀行及中間行所收取之相關費用由誰負擔?",
# "category": "insurance"
# },

alpha = 0.5

if not question:
response = jsonify({'qid': '1', 'retrieve': '1'})
response.status_code = 200
return response
else:
try:
response = search_do(question, category, source)
response = search_do(question, category, source, alpha)
response = {
'qid': qid,
'retrieve': int(response)
Expand Down
13 changes: 7 additions & 6 deletions src/utils/weaviate_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain.embeddings import OpenAIEmbeddings

import utils.config_log as config_log
# import config_log as config_log

config, logger, CONFIG_PATH = config_log.setup_config_and_logging()
config.read(CONFIG_PATH)
Expand Down Expand Up @@ -72,16 +73,16 @@ def hybrid_search(self, query, source, num, alpha):
return results


def search_do(question, category, source):
def search_do(question, category, source, alpha):
if category == "finance":
vdb_named = "Finance"
vdb_named = "Financedev"
elif category == "insurance":
vdb_named = "Insurance"
vdb_named = "Insurancedev"
else:
vdb_named = "Faq"
vdb_named = "Faqdev"

searcher = WeaviateSemanticSearch(vdb_named)
results = searcher.hybrid_search(question, source, 1, alpha=0.5)
results = searcher.hybrid_search(question, source, 1, alpha=alpha)

result_li = []
for _, result in enumerate(results, 1):
Expand All @@ -92,7 +93,7 @@ def search_do(question, category, source):


if __name__ == '__main__':
vdb = "Faq"
vdb = "Insurancedev"
client = WeaviateSemanticSearch(vdb)

# 統計筆數
Expand Down

0 comments on commit 0b04e8c

Please sign in to comment.