-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_test.py
22 lines (18 loc) · 1.09 KB
/
main_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import argparse
import os
import main_constants
from retrieval.neural.export_dataset import evaluate_test_set
if __name__ == '__main__':
    # Evaluate every trained model (or a single named one) on the test set.
    # Only models under MODEL_BASE_DIR that have a saved best-model file are
    # evaluated; for each one, evaluate_test_set(model, outputdir) is called.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', type=str, required=True,
                        choices=['overlap', 'uni_tfidf', 'bi_tfidf', 'prf_lm', 'max_pool_bllr_pw', 'max_pool_llr_pw',
                                 'mean_pool_bllr_pw', 'mean_pool_llr_pw', 'gru_llr_pw',
                                 'max_pool_llr_features_pw', 'max_pool_llr_embeddings_pw', 'max_pool_llr_full_pw',
                                 'all'],
                        help='Which model to evaluate on the test set (or "all").')
    # NOTE: this option used to declare choices=['report', 'show', 'save',
    # 'evaluation'], which excluded its own default 'test' (argparse does not
    # validate defaults against choices) and contradicted the help text. It is
    # a free-form directory name, so the restriction was removed; every value
    # accepted before is still accepted.
    parser.add_argument('-o', '--outputdir', type=str, default='test',
                        help='Directory to save the hotpot files in')
    # parse_known_args (rather than parse_args) tolerates extra CLI arguments,
    # e.g. ones meant for other components; they are intentionally ignored.
    args, _ = parser.parse_known_args()

    evaluated = 0
    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    for model in sorted(os.listdir(main_constants.MODEL_BASE_DIR)):
        # Cheap name filter first, before touching the filesystem.
        if args.model != 'all' and model != args.model:
            continue
        # L2R_BEST_MODEL is a path template filled with the model name; skip
        # entries that have no saved best-model file.
        if not os.path.isfile(main_constants.L2R_BEST_MODEL.format(model)):
            continue
        evaluate_test_set(model, args.outputdir)
        evaluated += 1

    # Previously a missing/untrained model resulted in a silent no-op.
    if evaluated == 0:
        print("No trained model found for '{}' under {}".format(
            args.model, main_constants.MODEL_BASE_DIR), file=sys.stderr)