Skip to content

Commit

Permalink
fix: pre-commit issue
Browse files Browse the repository at this point in the history
  • Loading branch information
JustinHsu1019 committed Nov 10, 2024
1 parent dc9bfb2 commit 17c2a9f
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 117 deletions.
20 changes: 11 additions & 9 deletions data/conbine_result.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
import json

# 載入 aicup_noocr.json 和 aicup_ref.json
with open('data/aicup_noocr.json', 'r', encoding='utf-8') as file:
with open('data/aicup_noocr.json', encoding='utf-8') as file:
noocr_data = json.load(file)

with open('data/aicup_ref.json', 'r', encoding='utf-8') as file:
with open('data/aicup_ref.json', encoding='utf-8') as file:
ref_data = json.load(file)

# 建立 ref_data 的 dictionary,並檢查 content 是否為字串,再去除空格
ref_dict = {
(item["category"], item["pid"]): ''.join(item["content"].split()) if isinstance(item["content"], str) else item["content"]
(item['category'], item['pid']): ''.join(item['content'].split())
if isinstance(item['content'], str)
else item['content']
for item in ref_data
}

# 更新 noocr_data 中空的 content
for item in noocr_data:
category = item["category"]
pid = item["pid"]
content = item["content"]
category = item['category']
pid = item['pid']
content = item['content']

# 如果 content 是 string 並且為空,則從 ref_data 裡填入去掉空格的 content
if isinstance(content, str) and content == "":
if isinstance(content, str) and content == '':
if (category, pid) in ref_dict:
item["content"] = ref_dict[(category, pid)]
item['content'] = ref_dict[(category, pid)]

# 將結果寫入 aicup_noocr_sec.json
with open('data/aicup_noocr_sec.json', 'w', encoding='utf-8') as file:
json.dump(noocr_data, file, ensure_ascii=False, indent=4)

print("已完成比對並生成 aicup_noocr_sec.json,並移除轉入的 content 中的空格(如果 content 是字串)")
print('已完成比對並生成 aicup_noocr_sec.json,並移除轉入的 content 中的空格(如果 content 是字串)')
15 changes: 10 additions & 5 deletions data/read_data_noocr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import json
import os

import pdfplumber
from tqdm import tqdm


# 讀取單個PDF文件並返回其文本內容
def read_pdf(pdf_loc):
pdf = pdfplumber.open(pdf_loc)
Expand All @@ -14,24 +16,26 @@ def read_pdf(pdf_loc):
pdf.close()
return pdf_text


# 從指定資料夾載入PDF文件,並根據資料夾名稱設定category
def load_data_by_category(source_path, category):
pdf_files = [f for f in os.listdir(source_path) if f.endswith('.pdf')]
data = []
for file in tqdm(pdf_files):
pid = file.replace('.pdf', '') # 擷取檔案名稱作為pid
content = read_pdf(os.path.join(source_path, file)) # 讀取PDF內文
data.append({"category": category, "pid": pid, "content": content})
data.append({'category': category, 'pid': pid, 'content': content})
return data


# 主程式
def generate_json(output_path):
all_data = []

# 載入不同類別的PDF資料
source_paths = {
"finance": "reference/finance", # finance 資料夾的路徑
"insurance": "reference/insurance" # insurance 資料夾的路徑
'finance': 'reference/finance', # finance 資料夾的路徑
'insurance': 'reference/insurance', # insurance 資料夾的路徑
}

# 遍歷每個類別的資料夾並載入資料
Expand All @@ -43,6 +47,7 @@ def generate_json(output_path):
with open(output_path, 'w', encoding='utf8') as f:
json.dump(all_data, f, ensure_ascii=False, indent=4)


# 設定輸出路徑
output_path = 'data/aicup_noocr.json'
generate_json(output_path)
44 changes: 25 additions & 19 deletions src/db_insert.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import time
import json
from langchain.text_splitter import RecursiveCharacterTextSplitter
import time

import utils.config_log as config_log
import weaviate
from langchain.text_splitter import RecursiveCharacterTextSplitter

config, logger, CONFIG_PATH = config_log.setup_config_and_logging()
config.read(CONFIG_PATH)
Expand All @@ -13,6 +14,7 @@
# Token limit for OpenAI model
TOKEN_LIMIT = 8192


class WeaviateManager:
def __init__(self, classnm):
self.url = wea_url
Expand All @@ -28,7 +30,11 @@ def check_class_exist(self):
'class': self.classnm,
'properties': [
{'name': 'pid', 'dataType': ['text']},
{'name': 'content', 'dataType': ['text'], "tokenization": "gse"}, # `gse` implements the "Jieba" algorithm, which is a popular Chinese text segmentation algorithm.
{
'name': 'content',
'dataType': ['text'],
'tokenization': 'gse',
}, # `gse` implements the "Jieba" algorithm, which is a popular Chinese text segmentation algorithm.
],
'vectorizer': 'text2vec-openai',
'moduleConfig': {
Expand All @@ -51,19 +57,19 @@ def insert_data(self, pid, content):
error_msg = str(e)
# 檢查是否是因為 token 長度過長
if 'maximum context length' in error_msg:
print(f"Content too long for pid: {pid}. Splitting content.")
return "TOO_LONG" # 特殊回傳值表達需要分割
print(f'Content too long for pid: {pid}. Splitting content.')
return 'TOO_LONG' # 特殊回傳值表達需要分割
elif '429' in error_msg:
print(f'Rate limit exceeded, retrying in 5 seconds... (Attempt {attempt + 1}/{max_retries})')
time.sleep(5)
else:
print(f"Unexpected Error for pid: {pid} - {error_msg}")
print(f'Unexpected Error for pid: {pid} - {error_msg}')
return False
except Exception as e:
print(f'Error inserting data for pid: {pid}, category: {self.classnm} - {str(e)}')
return False
# 超過最大重試次數
print(f"Failed to insert data for pid: {pid} after {max_retries} attempts.")
print(f'Failed to insert data for pid: {pid} after {max_retries} attempts.')
return False

def split_and_insert(self, pid, content, category):
Expand All @@ -73,10 +79,10 @@ def split_and_insert(self, pid, content, category):

# 逐段插入分割後的文本,保持相同的 pid 和 category
for idx, part in enumerate(split_content):
print(f"Inserting split content part {idx + 1} for pid: {pid}")
print(f'Inserting split content part {idx + 1} for pid: {pid}')
success = self.insert_data(pid, part)
if not success:
failed_records.append({"pid": pid, "category": category})
failed_records.append({'pid': pid, 'category': category})


if __name__ == '__main__':
Expand All @@ -90,32 +96,32 @@ def split_and_insert(self, pid, content, category):
pid = item['pid']
content = item['content']

if category == "faq":
classnm = "faqdev"
if category == 'faq':
classnm = 'faqdev'
content_str = json.dumps(content, ensure_ascii=False, indent=4)
elif category == "insurance":
classnm = "insurancedev"
elif category == 'insurance':
classnm = 'insurancedev'
content_str = content
elif category == "finance":
classnm = "financedev"
elif category == 'finance':
classnm = 'financedev'
content_str = json.dumps(content, ensure_ascii=False, indent=4) if isinstance(content, dict) else content
else:
print("Unknown category, skipping item.")
print('Unknown category, skipping item.')
continue

manager = WeaviateManager(classnm)
result = manager.insert_data(pid, content_str)

# 如果內容過長需要切割
if result == "TOO_LONG":
if result == 'TOO_LONG':
manager.split_and_insert(pid, content_str, category)
elif not result: # 如果失敗且非長度問題
failed_records.append({"pid": pid, "category": category})
failed_records.append({'pid': pid, 'category': category})

# 將失敗的資料寫入 JSON 檔案
if failed_records:
with open('failed_imports.json', 'w', encoding='utf-8') as f:
json.dump(failed_records, f, ensure_ascii=False, indent=4)
print("Failed records have been written to 'failed_imports.json'")
else:
print("All records imported successfully.")
print('All records imported successfully.')
10 changes: 4 additions & 6 deletions src/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_restx import Api, Resource, fields

# from utils.weaviate_op import search_do
from utils.weaviatexreranker import search_do
from werkzeug.security import check_password_hash, generate_password_hash
Expand Down Expand Up @@ -45,7 +46,7 @@ def verify_password(username, password):
'qid': fields.Integer(required=True, description='qid of the question'),
'source': fields.List(fields.Integer, required=True, description='source of the question'),
'query': fields.String(required=True, description='The message to the chatbot'),
'category': fields.String(required=True, description='The category of the question')
'category': fields.String(required=True, description='The category of the question'),
},
)

Expand Down Expand Up @@ -77,7 +78,7 @@ def post(self):
# "query": "匯款銀行及中間行所收取之相關費用由誰負擔?",
# "category": "insurance"
# },

alpha = 0.5

if not question:
Expand All @@ -87,10 +88,7 @@ def post(self):
else:
try:
response = search_do(question, category, source, alpha)
response = {
'qid': qid,
'retrieve': int(response)
}
response = {'qid': qid, 'retrieve': int(response)}

response = jsonify(response)

Expand Down
36 changes: 17 additions & 19 deletions src/tools/automate.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,49 @@
import requests
import json
import time # Import time module for timing

import requests

# Load questions from the JSON file
with open('data/questions_example.json', 'r', encoding='utf-8') as file:
with open('data/questions_example.json', encoding='utf-8') as file:
questions = json.load(file)['questions']

output_data = {"answers": []} # Initialize output format with "answers" array
output_data = {'answers': []} # Initialize output format with "answers" array

url = "http://127.0.0.1:5000/api/chat"
url = 'http://127.0.0.1:5000/api/chat'

total_start_time = time.time() # Start timing for the entire process

for question in questions:
question_start_time = time.time() # Start timing for each question

# Send POST request
response = requests.post(url, json=question)

if response.status_code == 200:
response_json = response.json()

# Extract qid and retrieve from the API response
qid = question.get("qid") # Assuming each question has a unique "qid" field
retrieve = response_json.get("retrieve")
qid = question.get('qid') # Assuming each question has a unique "qid" field
retrieve = response_json.get('retrieve')

# Append formatted result to the answers array
output_data["answers"].append({
"qid": qid,
"retrieve": retrieve
})
print("成功取得 JSON:", response_json)
output_data['answers'].append({'qid': qid, 'retrieve': retrieve})
print('成功取得 JSON:', response_json)
else:
print("請求失敗,狀態碼:", response.status_code)
print('請求失敗,狀態碼:', response.status_code)

# Calculate and print time for each question
question_end_time = time.time()
question_duration = question_end_time - question_start_time
print(f"QID: {qid} - 花費時間: {question_duration:.2f}")
print(f'QID: {qid} - 花費時間: {question_duration:.2f}')

# Calculate and print total time
total_end_time = time.time()
total_duration = total_end_time - total_start_time
print(f"全部題目處理完成,總共花費時間: {total_duration:.2f}")
print(f'全部題目處理完成,總共花費時間: {total_duration:.2f}')

# Save the output data to a new JSON file
with open('data/pred_retrieve.json', 'w', encoding='utf-8') as output_file:
json.dump(output_data, output_file, ensure_ascii=False, indent=4)

print("合併輸出已保存到 pred_retrieve.json 文件中。")
print('合併輸出已保存到 pred_retrieve.json 文件中。')
26 changes: 13 additions & 13 deletions src/tools/checkans.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,39 @@
from collections import defaultdict

# Load ground truth data
with open('data/ground_truths_example.json', 'r') as f:
ground_truths = json.load(f)["ground_truths"]
with open('data/ground_truths_example.json') as f:
ground_truths = json.load(f)['ground_truths']

# Load predicted data with the new format
with open('data/pred_retrieve.json', 'r') as f:
pred_retrieves = json.load(f)["answers"]
with open('data/pred_retrieve.json') as f:
pred_retrieves = json.load(f)['answers']

# Create a dictionary from predictions for easy lookup
pred_dict = {item["qid"]: item["retrieve"] for item in pred_retrieves}
pred_dict = {item['qid']: item['retrieve'] for item in pred_retrieves}

# Initialize counters and data structures
incorrect_qids = []
correct_count = 0
category_counts = defaultdict(lambda: {"correct": 0, "total": 0})
category_counts = defaultdict(lambda: {'correct': 0, 'total': 0})

# Compare predictions to ground truth
for ground in ground_truths:
qid = ground["qid"]
category = ground["category"]
correct_retrieve = ground["retrieve"]
qid = ground['qid']
category = ground['category']
correct_retrieve = ground['retrieve']
predicted_retrieve = pred_dict.get(qid)

if predicted_retrieve == correct_retrieve:
correct_count += 1
category_counts[category]["correct"] += 1
category_counts[category]['correct'] += 1
else:
incorrect_qids.append(qid)

category_counts[category]["total"] += 1
category_counts[category]['total'] += 1

# Print results
print("錯誤的題目 QID:", incorrect_qids)
print(f"總正確題數: {correct_count} / {len(ground_truths)}")
print('錯誤的題目 QID:', incorrect_qids)
print(f'總正確題數: {correct_count} / {len(ground_truths)}')

for category, counts in category_counts.items():
print(f"類別 {category}: {counts['correct']} / {counts['total']}")
Loading

0 comments on commit 17c2a9f

Please sign in to comment.