-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
executable file
·89 lines (71 loc) · 3.37 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import ast
import argparse
import sys
import numpy as np
# Fixed evaluation settings.

# Largest accumulated collection length (in words) the recall histogram is
# sized for; collections are not expected to grow beyond this.
MAX_NUM_WORD_LARGE_ENOUGH = 20_000

# Width, in words, of a single histogram bin.
NUM_WORD_HIST_BIN_WIDTH = 100

# Directory that holds the run files to be evaluated.
RUNFILE_DIR = "runs"

# JSON file with the question/answer pairs used as ground truth.
QA_PATH = "DensePhrases/densephrases-data/open-qa/nq-open/test_preprocessed.json"
def _load_answers_by_qid(qa_path):
    """Read the QA JSON file and return ``(answers_by_qid, num_query)``.

    The file is expected to hold a top-level ``data`` list whose items have
    at least the keys ``id`` and ``answers``.
    """
    with open(qa_path, 'r', encoding='utf-8') as fr:
        samples = json.load(fr)['data']
    answers_by_qid = {sample['id']: sample['answers'] for sample in samples}
    return answers_by_qid, len(samples)


def _recall_by_bin(retrieved, answers, num_bin, bin_width):
    """Return ``(recall_row, word_count_sum)`` for a single query.

    ``recall_row[b]`` is the fraction of ``answers`` found (as case-sensitive
    substrings) in the retrieved texts once the accumulated collection length
    reaches bin ``b``; bins that were never hit carry the previous value.
    """
    recall_row = np.zeros(num_bin)
    num_answers = len(answers)
    found = [False] * num_answers
    last_bin = num_bin - 1
    word_count_sum = 0
    for text in retrieved:
        word_count_sum += len(text.split(' '))
        # Clamp so an over-long collection lands in the last bin instead of
        # indexing past the end of the histogram.
        bin_idx = min(word_count_sum // bin_width, last_bin)

        # check whether text include answers or not
        for ans_idx, ans in enumerate(answers):
            if not found[ans_idx] and ans in text:
                found[ans_idx] = True

        # A query without any reference answer cannot be recalled: score 0
        # rather than dividing by zero.
        recall_row[bin_idx] = sum(found) / num_answers if num_answers else 0.0

    # Recall never decreases as more text is accumulated, so a running
    # maximum fills the untouched bins with the last measured value.
    return np.maximum.accumulate(recall_row), word_count_sum


def eval(args):
    """Print recall as a function of accumulated collection length, and mAR.

    Reads the question/answer pairs from ``QA_PATH`` and the run file
    ``RUNFILE_DIR/args.runfile_name``. Every non-blank run-file line must be
    ``<qid>\\t<python literal list of retrieved texts>``. For each query the
    retrieved texts are accumulated in order, and the recall reached at each
    collection length (in bins of ``NUM_WORD_HIST_BIN_WIDTH`` words) is
    recorded; recall is then macro-averaged over all queries of the QA file.

    Args:
        args: namespace with a ``runfile_name`` attribute.

    Prints:
        The per-bin mean recall array, then ``mean average recall = <mAR>``.

    Raises:
        KeyError: a run-file qid is missing from the QA file.
        IndexError: the run file has more non-blank lines than the QA file
            has queries.

    NOTE: the name shadows the builtin ``eval``; it is kept so existing
    callers keep working.
    """
    # read q&a pairs
    answers_by_qid, num_query = _load_answers_by_qid(QA_PATH)

    # evaluate recall with varying collection length
    runfile_path = f"{RUNFILE_DIR}/{args.runfile_name}"
    num_bin = MAX_NUM_WORD_LARGE_ENOUGH // NUM_WORD_HIST_BIN_WIDTH + 1
    # one row per query: recall is measured in `macro` fashion
    recall_by_collection_len_per_sample = np.zeros((num_query, num_bin))
    max_word_count_sum = 0  # will be used to truncate unnecessary bin_idx
    with open(runfile_path, 'r', encoding='utf-8') as fr:
        # Skip blank lines (e.g. a trailing newline) instead of crashing on
        # the tuple unpacking below.
        non_blank_lines = (line for line in fr if line.strip())
        for q_idx, line in enumerate(non_blank_lines):
            qid, retrieved = line.rstrip('\n').split('\t', 1)
            retrieved = ast.literal_eval(retrieved)
            recall_row, word_count_sum = _recall_by_bin(
                retrieved, answers_by_qid[qid], num_bin, NUM_WORD_HIST_BIN_WIDTH)
            recall_by_collection_len_per_sample[q_idx] = recall_row
            max_word_count_sum = max(max_word_count_sum, word_count_sum)

    # prune unused bin indices
    bin_idx_max = min(max_word_count_sum // NUM_WORD_HIST_BIN_WIDTH, num_bin - 1)
    recall_by_collection_len_per_sample = recall_by_collection_len_per_sample[:, :bin_idx_max + 1]
    recall_by_collection_len = np.mean(recall_by_collection_len_per_sample, axis=0)
    print(recall_by_collection_len)

    # get mean average recall (mAR)
    mAR = sum(recall_by_collection_len) / len(recall_by_collection_len)
    print(f'mean average recall = {mAR}')
if __name__ == "__main__":
    # Script entry point: read the run-file name from the command line and
    # evaluate it.

    # parse arguments
    parser = argparse.ArgumentParser(description='Evaluate recall with varying collection length.')
    parser.add_argument('--runfile_name', type=str, default="run.tsv",
                        help="output runfile name which includes query id and retrieved collection")
    args = parser.parse_args()

    # to prevent collision with DensePhrase native argparser:
    # drop our own flags so any later parser only sees the program name
    sys.argv = [sys.argv[0]]

    # run
    eval(args)