Commit ce09d78 (parent 2816bc5): 16 changed files with 54,402 additions and 65 deletions.
.gitignore
@@ -1,5 +1,6 @@
# secret ini
config.ini
reference/

# Byte-compiled / optimized / DLL files
__pycache__/
(A file deleted by this commit is not shown.)
bm25_retrieve.py (new file)
@@ -0,0 +1,252 @@
# python3 bm25_retrieve.py --question_path data/questions_example.json --source_path reference --output_path data/pred_retrieve.json

import os
import json
import argparse

from tqdm import tqdm
import jieba  # Chinese word segmentation
import pdfplumber  # extracts text from PDF files
from rank_bm25 import BM25Okapi  # BM25 ranking for document retrieval


# Load the reference data: returns a dict whose keys are the file names (as ints)
# and whose values are the text content of the corresponding PDFs.
def load_data(source_path):
    masked_file_ls = os.listdir(source_path)  # list the files in the folder
    # Read each PDF and store its text keyed by its numeric file name.
    corpus_dict = {int(file.replace('.pdf', '')): read_pdf(os.path.join(source_path, file))
                   for file in tqdm(masked_file_ls)}
    return corpus_dict
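
# For example, a source folder containing "1.pdf" and "2.pdf" yields
# {1: "<text of 1.pdf>", 2: "<text of 2.pdf>"}. Any file whose name is not of the
# form "<int>.pdf" (e.g. a stray ".DS_Store") would make int() raise ValueError.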


# Read a single PDF file and return its text content.
def read_pdf(pdf_loc, page_infos: list = None):
    pdf = pdfplumber.open(pdf_loc)  # open the given PDF file

    # TODO: other loaders may be used here instead, and multimodal content in the
    # PDF (tables, images, ...) can be processed.

    # If a page range is given, extract only those pages; otherwise extract all pages.
    pages = pdf.pages[page_infos[0]:page_infos[1]] if page_infos else pdf.pages
    pdf_text = ''
    for page in pages:  # iterate over the pages
        text = page.extract_text()  # extract the page's text
        if text:
            pdf_text += text
    pdf.close()  # close the PDF file

    return pdf_text  # return the extracted text
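
# Usage sketch (the path is hypothetical): read_pdf("reference/finance/1.pdf",
# page_infos=[0, 3]) extracts only pages 0-2, since page_infos slices pdf.pages.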


# Retrieve the best-matching document for a query among the given candidate sources.
def BM25_retrieve(qs, source, corpus_dict):
    filtered_corpus = [corpus_dict[int(file)] for file in source]

    # [TODO] other retrieval methods can be swapped in here to improve accuracy.

    tokenized_corpus = [list(jieba.cut_for_search(doc)) for doc in filtered_corpus]  # tokenize each document
    bm25 = BM25Okapi(tokenized_corpus)  # build the BM25 retrieval model
    tokenized_query = list(jieba.cut_for_search(qs))  # tokenize the query
    ans = bm25.get_top_n(tokenized_query, list(filtered_corpus), n=1)  # most relevant document; n is tunable
    a = ans[0]
    # Map the best-matching text back to its file name.
    res = [key for key, value in corpus_dict.items() if value == a]
    return res[0]  # return the file name
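
# A minimal sketch of the jieba + BM25Okapi pattern used above, on toy in-memory
# documents (the strings are made up). Scoring by index with get_scores, as here,
# also avoids the text-equality lookup in BM25_retrieve, which can map to the
# wrong key if two files extract to identical text:
#
# toy_corpus = {1: "保險契約自第一期保險費交付時生效",
#               2: "流動資產包含現金及約當現金",
#               3: "董事會決議發放現金股利"}
# tokenized = [list(jieba.cut_for_search(doc)) for doc in toy_corpus.values()]
# scores = BM25Okapi(tokenized).get_scores(list(jieba.cut_for_search("保單何時生效")))
# best_key = max(zip(toy_corpus.keys(), scores), key=lambda kv: kv[1])[0]  # -> 1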


if __name__ == "__main__":
    # Parse command-line arguments with argparse.
    parser = argparse.ArgumentParser(description='Process some paths and files.')
    parser.add_argument('--question_path', type=str, required=True, help='path of the released questions file')
    parser.add_argument('--source_path', type=str, required=True, help='path of the reference data')
    parser.add_argument('--output_path', type=str, required=True, help='output path for answers in the competition format')

    args = parser.parse_args()  # parse the arguments

    answer_dict = {"answers": []}  # initialize the answer dict

    with open(args.question_path, 'r', encoding='utf8') as f:
        qs_ref = json.load(f)  # load the questions file

    source_path_insurance = os.path.join(args.source_path, 'insurance')  # reference data for the insurance category
    corpus_dict_insurance = load_data(source_path_insurance)

    source_path_finance = os.path.join(args.source_path, 'finance')  # reference data for the finance category
    corpus_dict_finance = load_data(source_path_finance)

    with open(os.path.join(args.source_path, 'faq/pid_map_content.json'), 'r', encoding='utf8') as f_s:
        key_to_source_dict = json.load(f_s)  # load the FAQ reference file
        key_to_source_dict = {int(key): value for key, value in key_to_source_dict.items()}

    for q_dict in qs_ref['questions']:
        if q_dict['category'] == 'finance':
            # Run retrieval.
            retrieved = BM25_retrieve(q_dict['query'], q_dict['source'], corpus_dict_finance)
            # Append the result to the answer dict.
            answer_dict['answers'].append({"qid": q_dict['qid'], "retrieve": retrieved})

        elif q_dict['category'] == 'insurance':
            retrieved = BM25_retrieve(q_dict['query'], q_dict['source'], corpus_dict_insurance)
            answer_dict['answers'].append({"qid": q_dict['qid'], "retrieve": retrieved})

        elif q_dict['category'] == 'faq':
            corpus_dict_faq = {key: str(value) for key, value in key_to_source_dict.items() if key in q_dict['source']}
            retrieved = BM25_retrieve(q_dict['query'], q_dict['source'], corpus_dict_faq)
            answer_dict['answers'].append({"qid": q_dict['qid'], "retrieve": retrieved})

        else:
            raise ValueError("Something went wrong")  # unknown category

    # Save the answer dict as a JSON file.
    with open(args.output_path, 'w', encoding='utf8') as f:
        json.dump(answer_dict, f, ensure_ascii=False, indent=4)  # keep indentation and non-ASCII characters
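
# For reference, a sketch of the JSON shapes this script assumes. Only the keys
# are taken from the code above; the example values are hypothetical:
#
# questions_example.json:
#   {"questions": [{"qid": 1, "category": "insurance", "query": "...", "source": [1, 7, 23]}]}
#
# pred_retrieve.json (written by this script):
#   {"answers": [{"qid": 1, "retrieve": 7}]}
#
# The commented-out blocks below are earlier/alternative implementations kept in the file.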


# import os
# import json
# import argparse
# from tqdm import tqdm
# import jieba  # Chinese word segmentation
# from rank_bm25 import BM25Okapi  # BM25 ranking for document retrieval

# # Load the reference data: returns a dict whose keys are pids and whose values are text content.
# def load_data_from_json(json_path):
#     with open(json_path, 'r', encoding='utf-8') as f:
#         data = json.load(f)
#     corpus_dict = {}
#     for entry in data:
#         category = entry['category']
#         pid = entry['pid']
#         content = entry['content']
#
#         # For the FAQ category, merge the question and its answers into one text.
#         if category == "faq":
#             text = content['question'] + " " + " ".join(content['answers'])
#         else:
#             text = content  # other categories use the content directly
#
#         corpus_dict[int(pid)] = text
#     return corpus_dict

# # Retrieve the best-matching document for a query among the given candidate sources.
# def BM25_retrieve(qs, source, corpus_dict):
#     filtered_corpus = [corpus_dict[int(file)] for file in source]
#
#     # [TODO] other retrieval methods can be swapped in here to improve accuracy.
#
#     tokenized_corpus = [list(jieba.cut_for_search(doc)) for doc in filtered_corpus]  # tokenize each document
#     bm25 = BM25Okapi(tokenized_corpus)  # build the BM25 retrieval model
#     tokenized_query = list(jieba.cut_for_search(qs))  # tokenize the query
#     ans = bm25.get_top_n(tokenized_query, list(filtered_corpus), n=1)  # most relevant document; n is tunable
#     a = ans[0]
#     # Map the best-matching text back to its file name.
#     res = [key for key, value in corpus_dict.items() if value == a]
#     return res[0]  # return the file name

# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description='Process question path and output path.')
#     parser.add_argument('--question_path', type=str, required=True, help='path of the released questions file')
#     parser.add_argument('--output_path', type=str, required=True, help='output path for answers in the competition format')
#
#     args = parser.parse_args()  # parse the arguments
#
#     answer_dict = {"answers": []}  # initialize the answer dict
#
#     with open(args.question_path, 'r', encoding='utf8') as f:
#         qs_ref = json.load(f)  # load the questions file
#
#     corpus_dict = load_data_from_json('data/aicup_ref.json')  # load the fixed JSON reference file
#
#     for q_dict in qs_ref['questions']:
#         category = q_dict['category']
#         retrieved = BM25_retrieve(q_dict['query'], q_dict['source'], corpus_dict)
#
#         # Append the result to the answer dict.
#         answer_dict['answers'].append({"qid": q_dict['qid'], "retrieve": retrieved})
#
#     # Save the answer dict as a JSON file.
#     with open(args.output_path, 'w', encoding='utf8') as f:
#         json.dump(answer_dict, f, ensure_ascii=False, indent=4)  # keep indentation and non-ASCII characters


# # import os
# # import json
# # import argparse
# # import jieba
# # from tqdm import tqdm
# # from rank_bm25 import BM25Okapi
# # from openai.embeddings_utils import get_embedding
# # from langchain.text_splitter import RecursiveCharacterTextSplitter
# # import openai

# # openai.api_key = "sk-..."  # key redacted here; a live key should never be committed

# # MAX_TOKENS = 5000
# # EMBEDDING_MODEL = "text-embedding-3-large"

# # def load_data(file_path):
# #     with open(file_path, 'r', encoding='utf-8') as f:
# #         data = json.load(f)
# #     corpus_dict = {}
# #     for item in data:
# #         pid = int(item['pid'])
# #         if item['category'] == 'faq':
# #             content = item['content']['question'] + " " + " ".join(item['content']['answers'])
# #         else:
# #             content = item['content']
# #         corpus_dict[pid] = content
# #     return corpus_dict

# # def chunk_text(text, max_tokens=MAX_TOKENS):
# #     text_splitter = RecursiveCharacterTextSplitter(chunk_size=max_tokens, chunk_overlap=500)
# #     return text_splitter.split_text(text)

# # def get_text_embedding(text):
# #     response = openai.Embedding.create(
# #         input=text.replace("\n", " "),
# #         model=EMBEDDING_MODEL
# #     )
# #     return response['data'][0]['embedding']

# # def combined_retrieve(qs, source, corpus_dict, weight_bm25=0.5, weight_embedding=0.5):
# #     filtered_corpus = [corpus_dict[int(pid)] for pid in source]

# #     # BM25 retrieval
# #     tokenized_corpus = [list(jieba.cut_for_search(doc)) for doc in filtered_corpus]
# #     bm25 = BM25Okapi(tokenized_corpus)
# #     tokenized_query = list(jieba.cut_for_search(qs))
# #     bm25_scores = bm25.get_scores(tokenized_query)

# #     # Embedding-based retrieval
# #     query_embedding = get_text_embedding(qs)
# #     embedding_scores = []
# #     for doc in filtered_corpus:
# #         chunks = chunk_text(doc)  # chunk the text if it exceeds the token limit
# #         chunk_embeddings = [get_text_embedding(chunk) for chunk in chunks]
# #         max_chunk_score = max([openai.embeddings_utils.cosine_similarity(query_embedding, emb) for emb in chunk_embeddings])
# #         embedding_scores.append(max_chunk_score)

# #     # Combined score
# #     combined_scores = [weight_bm25 * bm25_score + weight_embedding * embed_score
# #                        for bm25_score, embed_score in zip(bm25_scores, embedding_scores)]
# #     best_idx = combined_scores.index(max(combined_scores))
# #     return source[best_idx]  # return the best-matching pid from source
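
# A note on the weighted sum above: raw BM25 scores are unbounded while cosine
# similarities lie in [-1, 1], so the raw mix tends to be dominated by BM25. One
# common fix (an assumption, not part of the original code) is to min-max
# normalize each score list before combining:
#
# def min_max(xs):
#     lo, hi = min(xs), max(xs)
#     return [0.0] * len(xs) if hi == lo else [(x - lo) / (hi - lo) for x in xs]
#
# combined_scores = [weight_bm25 * b + weight_embedding * e
#                    for b, e in zip(min_max(list(bm25_scores)), min_max(embedding_scores))]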

# # if __name__ == "__main__":
# #     parser = argparse.ArgumentParser(description='Process some paths and files.')
# #     parser.add_argument('--question_path', type=str, required=True, help='path of the released questions file')
# #     parser.add_argument('--source_path', type=str, required=True, help='path of the reference data')
# #     parser.add_argument('--output_path', type=str, required=True, help='output path for answers in the competition format')

# #     args = parser.parse_args()
# #     answer_dict = {"answers": []}

# #     with open(args.question_path, 'r', encoding='utf-8') as f:
# #         qs_ref = json.load(f)

# #     corpus_dict = load_data(os.path.join(args.source_path, 'aicup_ref.json'))

# #     for q_dict in tqdm(qs_ref['questions'], desc="Processing questions"):
# #         category = q_dict['category']
# #         source = q_dict['source']
# #         query = q_dict['query']
# #         retrieved = combined_retrieve(query, source, corpus_dict)
# #         answer_dict['answers'].append({"qid": q_dict['qid'], "retrieve": retrieved})

# #     with open(args.output_path, 'w', encoding='utf-8') as f:
# #         json.dump(answer_dict, f, ensure_ascii=False, indent=4)