Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Repo structure (after competition) #2

Merged
merged 3 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# aicup-rag
# AI CUP 2024 玉山人工智慧公開挑戰賽-RAG與LLM在金融問答的應用

## Result

- Total: 38 / 487 Teams
- Leaderboard: 38 / 222

![AI Cup Result](img/aicup_result.png)

## Development Mode
To set up the development environment, follow these steps:
Expand Down
20 changes: 11 additions & 9 deletions data/conbine_result.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
import json

# 載入 aicup_noocr.json 和 aicup_ref.json
with open('data/aicup_noocr.json', 'r', encoding='utf-8') as file:
with open('data/aicup_noocr.json', encoding='utf-8') as file:
noocr_data = json.load(file)

with open('data/aicup_ref.json', 'r', encoding='utf-8') as file:
with open('data/aicup_ref.json', encoding='utf-8') as file:
ref_data = json.load(file)

# 建立 ref_data 的 dictionary,並檢查 content 是否為字串,再去除空格
ref_dict = {
(item["category"], item["pid"]): ''.join(item["content"].split()) if isinstance(item["content"], str) else item["content"]
(item['category'], item['pid']): ''.join(item['content'].split())
if isinstance(item['content'], str)
else item['content']
for item in ref_data
}

# 更新 noocr_data 中空的 content
for item in noocr_data:
category = item["category"]
pid = item["pid"]
content = item["content"]
category = item['category']
pid = item['pid']
content = item['content']

# 如果 content 是 string 並且為空,則從 ref_data 裡填入去掉空格的 content
if isinstance(content, str) and content == "":
if isinstance(content, str) and content == '':
if (category, pid) in ref_dict:
item["content"] = ref_dict[(category, pid)]
item['content'] = ref_dict[(category, pid)]

# 將結果寫入 aicup_noocr_sec.json
with open('data/aicup_noocr_sec.json', 'w', encoding='utf-8') as file:
json.dump(noocr_data, file, ensure_ascii=False, indent=4)

print("已完成比對並生成 aicup_noocr_sec.json,並移除轉入的 content 中的空格(如果 content 是字串)")
print('已完成比對並生成 aicup_noocr_sec.json,並移除轉入的 content 中的空格(如果 content 是字串)')
15 changes: 10 additions & 5 deletions data/read_data_noocr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import json
import os

import pdfplumber
from tqdm import tqdm


# 讀取單個PDF文件並返回其文本內容
def read_pdf(pdf_loc):
pdf = pdfplumber.open(pdf_loc)
Expand All @@ -14,24 +16,26 @@ def read_pdf(pdf_loc):
pdf.close()
return pdf_text


# 從指定資料夾載入PDF文件,並根據資料夾名稱設定category
def load_data_by_category(source_path, category):
pdf_files = [f for f in os.listdir(source_path) if f.endswith('.pdf')]
data = []
for file in tqdm(pdf_files):
pid = file.replace('.pdf', '') # 擷取檔案名稱作為pid
content = read_pdf(os.path.join(source_path, file)) # 讀取PDF內文
data.append({"category": category, "pid": pid, "content": content})
data.append({'category': category, 'pid': pid, 'content': content})
return data


# 主程式
def generate_json(output_path):
all_data = []

# 載入不同類別的PDF資料
source_paths = {
"finance": "reference/finance", # finance 資料夾的路徑
"insurance": "reference/insurance" # insurance 資料夾的路徑
'finance': 'reference/finance', # finance 資料夾的路徑
'insurance': 'reference/insurance', # insurance 資料夾的路徑
}

# 遍歷每個類別的資料夾並載入資料
Expand All @@ -43,6 +47,7 @@ def generate_json(output_path):
with open(output_path, 'w', encoding='utf8') as f:
json.dump(all_data, f, ensure_ascii=False, indent=4)


# 設定輸出路徑
output_path = 'data/aicup_noocr.json'
generate_json(output_path)
Binary file added img/aicup_result.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
44 changes: 25 additions & 19 deletions src/db_insert.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import time
import json
from langchain.text_splitter import RecursiveCharacterTextSplitter
import time

import utils.config_log as config_log
import weaviate
from langchain.text_splitter import RecursiveCharacterTextSplitter

config, logger, CONFIG_PATH = config_log.setup_config_and_logging()
config.read(CONFIG_PATH)
Expand All @@ -13,6 +14,7 @@
# Token limit for OpenAI model
TOKEN_LIMIT = 8192


class WeaviateManager:
def __init__(self, classnm):
self.url = wea_url
Expand All @@ -28,7 +30,11 @@ def check_class_exist(self):
'class': self.classnm,
'properties': [
{'name': 'pid', 'dataType': ['text']},
{'name': 'content', 'dataType': ['text'], "tokenization": "gse"}, # `gse` implements the "Jieba" algorithm, which is a popular Chinese text segmentation algorithm.
{
'name': 'content',
'dataType': ['text'],
'tokenization': 'gse',
}, # `gse` implements the "Jieba" algorithm, which is a popular Chinese text segmentation algorithm.
],
'vectorizer': 'text2vec-openai',
'moduleConfig': {
Expand All @@ -51,19 +57,19 @@ def insert_data(self, pid, content):
error_msg = str(e)
# 檢查是否是因為 token 長度過長
if 'maximum context length' in error_msg:
print(f"Content too long for pid: {pid}. Splitting content.")
return "TOO_LONG" # 特殊回傳值表達需要分割
print(f'Content too long for pid: {pid}. Splitting content.')
return 'TOO_LONG' # 特殊回傳值表達需要分割
elif '429' in error_msg:
print(f'Rate limit exceeded, retrying in 5 seconds... (Attempt {attempt + 1}/{max_retries})')
time.sleep(5)
else:
print(f"Unexpected Error for pid: {pid} - {error_msg}")
print(f'Unexpected Error for pid: {pid} - {error_msg}')
return False
except Exception as e:
print(f'Error inserting data for pid: {pid}, category: {self.classnm} - {str(e)}')
return False
# 超過最大重試次數
print(f"Failed to insert data for pid: {pid} after {max_retries} attempts.")
print(f'Failed to insert data for pid: {pid} after {max_retries} attempts.')
return False

def split_and_insert(self, pid, content, category):
Expand All @@ -73,10 +79,10 @@ def split_and_insert(self, pid, content, category):

# 逐段插入分割後的文本,保持相同的 pid 和 category
for idx, part in enumerate(split_content):
print(f"Inserting split content part {idx + 1} for pid: {pid}")
print(f'Inserting split content part {idx + 1} for pid: {pid}')
success = self.insert_data(pid, part)
if not success:
failed_records.append({"pid": pid, "category": category})
failed_records.append({'pid': pid, 'category': category})


if __name__ == '__main__':
Expand All @@ -90,32 +96,32 @@ def split_and_insert(self, pid, content, category):
pid = item['pid']
content = item['content']

if category == "faq":
classnm = "faqdev"
if category == 'faq':
classnm = 'faqdev'
content_str = json.dumps(content, ensure_ascii=False, indent=4)
elif category == "insurance":
classnm = "insurancedev"
elif category == 'insurance':
classnm = 'insurancedev'
content_str = content
elif category == "finance":
classnm = "financedev"
elif category == 'finance':
classnm = 'financedev'
content_str = json.dumps(content, ensure_ascii=False, indent=4) if isinstance(content, dict) else content
else:
print("Unknown category, skipping item.")
print('Unknown category, skipping item.')
continue

manager = WeaviateManager(classnm)
result = manager.insert_data(pid, content_str)

# 如果內容過長需要切割
if result == "TOO_LONG":
if result == 'TOO_LONG':
manager.split_and_insert(pid, content_str, category)
elif not result: # 如果失敗且非長度問題
failed_records.append({"pid": pid, "category": category})
failed_records.append({'pid': pid, 'category': category})

# 將失敗的資料寫入 JSON 檔案
if failed_records:
with open('failed_imports.json', 'w', encoding='utf-8') as f:
json.dump(failed_records, f, ensure_ascii=False, indent=4)
print("Failed records have been written to 'failed_imports.json'")
else:
print("All records imported successfully.")
print('All records imported successfully.')
10 changes: 4 additions & 6 deletions src/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_restx import Api, Resource, fields

# from utils.weaviate_op import search_do
from utils.weaviatexreranker import search_do
from werkzeug.security import check_password_hash, generate_password_hash
Expand Down Expand Up @@ -45,7 +46,7 @@ def verify_password(username, password):
'qid': fields.Integer(required=True, description='qid of the question'),
'source': fields.List(fields.Integer, required=True, description='source of the question'),
'query': fields.String(required=True, description='The message to the chatbot'),
'category': fields.String(required=True, description='The category of the question')
'category': fields.String(required=True, description='The category of the question'),
},
)

Expand Down Expand Up @@ -77,7 +78,7 @@ def post(self):
# "query": "匯款銀行及中間行所收取之相關費用由誰負擔?",
# "category": "insurance"
# },

alpha = 0.5

if not question:
Expand All @@ -87,10 +88,7 @@ def post(self):
else:
try:
response = search_do(question, category, source, alpha)
response = {
'qid': qid,
'retrieve': int(response)
}
response = {'qid': qid, 'retrieve': int(response)}

response = jsonify(response)

Expand Down
36 changes: 17 additions & 19 deletions src/tools/automate.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,49 @@
import requests
import json
import time # Import time module for timing

import requests

# Load questions from the JSON file
with open('data/questions_example.json', 'r', encoding='utf-8') as file:
with open('data/questions_example.json', encoding='utf-8') as file:
questions = json.load(file)['questions']

output_data = {"answers": []} # Initialize output format with "answers" array
output_data = {'answers': []} # Initialize output format with "answers" array

url = "http://127.0.0.1:5000/api/chat"
url = 'http://127.0.0.1:5000/api/chat'

total_start_time = time.time() # Start timing for the entire process

for question in questions:
question_start_time = time.time() # Start timing for each question

# Send POST request
response = requests.post(url, json=question)

if response.status_code == 200:
response_json = response.json()

# Extract qid and retrieve from the API response
qid = question.get("qid") # Assuming each question has a unique "qid" field
retrieve = response_json.get("retrieve")
qid = question.get('qid') # Assuming each question has a unique "qid" field
retrieve = response_json.get('retrieve')

# Append formatted result to the answers array
output_data["answers"].append({
"qid": qid,
"retrieve": retrieve
})
print("成功取得 JSON:", response_json)
output_data['answers'].append({'qid': qid, 'retrieve': retrieve})
print('成功取得 JSON:', response_json)
else:
print("請求失敗,狀態碼:", response.status_code)
print('請求失敗,狀態碼:', response.status_code)

# Calculate and print time for each question
question_end_time = time.time()
question_duration = question_end_time - question_start_time
print(f"QID: {qid} - 花費時間: {question_duration:.2f} 秒")
print(f'QID: {qid} - 花費時間: {question_duration:.2f} 秒')

# Calculate and print total time
total_end_time = time.time()
total_duration = total_end_time - total_start_time
print(f"全部題目處理完成,總共花費時間: {total_duration:.2f} 秒")
print(f'全部題目處理完成,總共花費時間: {total_duration:.2f} 秒')

# Save the output data to a new JSON file
with open('data/pred_retrieve.json', 'w', encoding='utf-8') as output_file:
json.dump(output_data, output_file, ensure_ascii=False, indent=4)

print("合併輸出已保存到 pred_retrieve.json 文件中。")
print('合併輸出已保存到 pred_retrieve.json 文件中。')
26 changes: 13 additions & 13 deletions src/tools/checkans.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,39 @@
from collections import defaultdict

# Load ground truth data
with open('data/ground_truths_example.json', 'r') as f:
ground_truths = json.load(f)["ground_truths"]
with open('data/ground_truths_example.json') as f:
ground_truths = json.load(f)['ground_truths']

# Load predicted data with the new format
with open('data/pred_retrieve.json', 'r') as f:
pred_retrieves = json.load(f)["answers"]
with open('data/pred_retrieve.json') as f:
pred_retrieves = json.load(f)['answers']

# Create a dictionary from predictions for easy lookup
pred_dict = {item["qid"]: item["retrieve"] for item in pred_retrieves}
pred_dict = {item['qid']: item['retrieve'] for item in pred_retrieves}

# Initialize counters and data structures
incorrect_qids = []
correct_count = 0
category_counts = defaultdict(lambda: {"correct": 0, "total": 0})
category_counts = defaultdict(lambda: {'correct': 0, 'total': 0})

# Compare predictions to ground truth
for ground in ground_truths:
qid = ground["qid"]
category = ground["category"]
correct_retrieve = ground["retrieve"]
qid = ground['qid']
category = ground['category']
correct_retrieve = ground['retrieve']
predicted_retrieve = pred_dict.get(qid)

if predicted_retrieve == correct_retrieve:
correct_count += 1
category_counts[category]["correct"] += 1
category_counts[category]['correct'] += 1
else:
incorrect_qids.append(qid)

category_counts[category]["total"] += 1
category_counts[category]['total'] += 1

# Print results
print("錯誤的題目 QID:", incorrect_qids)
print(f"總正確題數: {correct_count} / {len(ground_truths)}")
print('錯誤的題目 QID:', incorrect_qids)
print(f'總正確題數: {correct_count} / {len(ground_truths)}')

for category, counts in category_counts.items():
print(f"類別 {category}: {counts['correct']} / {counts['total']}")
Loading
Loading