This repository has been archived by the owner on Aug 12, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathclassify.py
37 lines (31 loc) · 1.49 KB
/
classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
import pickle
from features import get_features
def _predict_styles(ref_strings, count_vectorizer, tfidf_transformer,
                    classifier):
    """Return the classifier's prediction for each reference string.

    The strings are passed to ``get_features`` together with the already
    fitted vectorizer and transformer; the third element of its result is
    what gets handed to ``classifier.predict``.
    """
    _, _, features = get_features(ref_strings,
                                  count_vectorizer=count_vectorizer,
                                  tfidf_transformer=tfidf_transformer)
    return classifier.predict(features)


if __name__ == '__main__':
    # Command-line entry point: classify a single reference string (-r)
    # and/or every line of an input file (-i).
    parser = argparse.ArgumentParser(
        description='Infer citation style based on the reference string')
    parser.add_argument('-m', '--model', help='model file', type=str,
                        default='models/default-model')
    parser.add_argument('-r', '--reference', help='reference string',
                        type=str)
    parser.add_argument('-i', '--input', help='file with reference strings',
                        type=str)
    parser.add_argument('-o', '--output', help='output file', type=str)
    args = parser.parse_args()

    # Without either option there is nothing to classify; report it instead
    # of loading the model and exiting silently.
    if args.reference is None and args.input is None:
        parser.error('nothing to classify: give --reference and/or --input')

    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
    # file. Only point --model at model files from a trusted source.
    with open(args.model, 'rb') as model_file:
        bundle = pickle.load(model_file)
    # The model file holds three objects, in this order.
    count_vectorizer, tfidf_transformer, classifier = tuple(bundle)

    if args.reference:
        prediction = _predict_styles([args.reference], count_vectorizer,
                                     tfidf_transformer, classifier)
        print(prediction[0])

    if args.input:
        # One reference string per line. Blank lines are kept so that output
        # line N always corresponds to input line N.
        with open(args.input, 'r') as input_file:
            ref_strings = [line.strip() for line in input_file]

        if ref_strings:
            styles = _predict_styles(ref_strings, count_vectorizer,
                                     tfidf_transformer, classifier)
        else:
            # Empty input file: skip prediction rather than pass zero
            # samples to the feature extractor and the classifier.
            styles = []

        # f-string instead of ``s + '\n'`` so non-str labels also work.
        lines = [f'{style}\n' for style in styles]
        if args.output is not None:
            with open(args.output, 'w') as output_file:
                output_file.writelines(lines)
        else:
            # No --output given: previously this crashed in open(None, 'w');
            # write the predictions to stdout instead.
            for line in lines:
                print(line, end='')