forked from tech-srl/code2seq
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_code2seq.py
89 lines (79 loc) · 2.2 KB
/
_code2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from argparse import ArgumentParser
import numpy as np
import tensorflow as tf
from config import Config
from interactive_predict import InteractivePredictor
from model import Model
def get_args(argv=None):
    """Parse the command-line options of the code2seq entry point.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behaviour
            of reading ``sys.argv[1:]``; passing a list makes the parser usable
            from tests and from other scripts.

    Returns:
        argparse.Namespace with the attributes ``data_path``, ``test_path``,
        ``save_path_prefix`` and ``load_path`` (``None`` when not given),
        ``release``, ``predict`` and ``debug`` (``False`` unless the flag is
        passed), and ``seed`` (``int``, default 239).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-d",
        "--data",
        dest="data_path",
        help="path to preprocessed dataset",
        required=False,
    )
    parser.add_argument(
        "-te",
        "--test",
        dest="test_path",
        help="path to test file",
        metavar="FILE",
        required=False,
    )
    parser.add_argument(
        "-s",
        "--save_prefix",
        dest="save_path_prefix",
        help="path to save file",
        metavar="FILE",
        required=False,
    )
    parser.add_argument(
        "-l",
        "--load",
        dest="load_path",
        help="path to saved file",
        metavar="FILE",
        required=False,
    )
    parser.add_argument(
        "--release",
        action="store_true",
        help="if specified and loading a trained model, release the loaded model for a smaller model "
        "size.",
    )
    parser.add_argument(
        "--predict",
        action="store_true",
        help="start an interactive prediction session with the model",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="use the debug configuration instead of the default one",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=239,
        help="random seed for NumPy and TensorFlow (default: %(default)s)",
    )
    # Passing None lets argparse fall back to sys.argv[1:], as before.
    return parser.parse_args(argv)
def main():
    """Run the code2seq command-line workflow.

    Parses the command line, seeds NumPy and TensorFlow, builds the model from
    the debug or default ``Config`` and then, in this order:

    * trains, if the config has a ``TRAIN_PATH``;
    * evaluates and prints accuracy / precision / recall / F1 / ROUGE, if the
      config has a ``TEST_PATH`` and no ``--data`` path was given;
    * starts an interactive predictor, if ``--predict`` was given;
    * runs ``model.evaluate(release=True)``, if both ``--release`` and
      ``--load`` were given.

    The model's session is closed even if one of these steps raises.
    """
    args = get_args()
    # Seed both NumPy and TensorFlow so runs are repeatable for a given --seed.
    np.random.seed(args.seed)
    tf.compat.v1.set_random_seed(args.seed)

    if args.debug:
        config = Config.get_debug_config(args)
    else:
        config = Config.get_default_config(args)

    model = Model(config)
    print("Created model")
    try:
        if config.TRAIN_PATH:
            model.train()
        # Stand-alone evaluation is skipped when a dataset (--data) is given.
        # NOTE(review): presumably because train() already evaluates on
        # TEST_PATH in that case - confirm in Model.train.
        if config.TEST_PATH and not args.data_path:
            results, precision, recall, f1, rouge = model.evaluate()
            print("Accuracy: " + str(results))
            print(
                "Precision: "
                + str(precision)
                + ", recall: "
                + str(recall)
                + ", F1: "
                + str(f1)
            )
            print("Rouge: ", rouge)
        if args.predict:
            predictor = InteractivePredictor(config, model)
            predictor.predict()
        if args.release and args.load_path:
            model.evaluate(release=True)
    finally:
        # Previously the session was only closed on the success path; close
        # it even when training/evaluation/prediction raises.
        model.close_session()


if __name__ == "__main__":
    main()