-
Notifications
You must be signed in to change notification settings - Fork 33
/
demo_open.py
99 lines (89 loc) · 2.83 KB
/
demo_open.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python
# coding=utf-8
import tensorflow as tf
import bottle
from bottle import route, run, static_file
import threading
import numpy as np
import os
from soqal import SOQAL
from time import sleep
import sys
import pickle
sys.path.append(os.path.abspath("retriever"))
from retriever.GoogleSearchRetriever import ApiGoogleSearchRetriever
from retriever.TfidfRetriever import HierarchicalTfidf
from retriever.TfidfRetriever import TfidfRetriever
sys.path.append(os.path.abspath("bert"))
from bert.Bert_model import BERT_model
'''
This file is taken and modified from R-Net by Minsangkim142
https://github.com/minsangkim142/R-net
'''
# Bottle application that serves the demo page, static assets and /answer.
app = bottle.Bottle()

# Hand-off slots shared by the /answer handler and the Demo worker thread:
# `query` holds the pending question (falsy when nothing is pending) and
# `response` holds the model's answer (falsy until an answer is ready).
query, response = [], ""

# Absolute locations derived from this module's own path.
my_module = os.path.abspath(__file__)
parent_dir = os.path.dirname(my_module)
static_dir = os.path.join(parent_dir, 'static')
@app.get("/")
def home():
    """Serve the demo landing page.

    The HTML file is resolved against this module's directory (the same
    ``parent_dir`` that ``static_dir`` is built from) instead of the current
    working directory, so the page is found regardless of where the server
    was launched from.

    Returns:
        str: the contents of ``demo_open.html`` decoded as UTF-8.
    """
    page_path = os.path.join(parent_dir, 'demo_open.html')
    with open(page_path, encoding='utf-8') as fl:
        return fl.read()
@app.get('/static/<filename>')
def server_static(filename):
    """Return the asset named *filename* from the bundled ``static`` folder.

    The lookup is delegated to ``bottle.static_file`` with ``static_dir``
    as the root directory.
    """
    asset_root = static_dir
    return static_file(filename, root=asset_root)
@app.post('/answer')
def answer():
    """Answer the question posted as a JSON body ``{"question": ...}``.

    The question is handed to the worker thread through the module-level
    ``query`` slot. This handler then polls ``response`` until the worker
    fills it in, returns it as ``{"answer": ...}``, and clears the slot for
    the next request.

    A missing, empty or otherwise falsy question (including a body that is
    not a JSON object) is answered immediately with a prompt. It is not
    forwarded because the worker ignores falsy queries, and the handler
    would otherwise wait forever.
    """
    global query, response
    payload = bottle.request.json or {}
    question = payload.get('question') if isinstance(payload, dict) else None
    print("received question: {}".format(question))
    if question:
        query = question
        # Busy-wait until the worker thread publishes an answer.
        while not response:
            sleep(0.1)
    else:
        response = "Please write a question"
    print("received response: {}".format(response))
    response_ = {"answer": response}
    # Reset the hand-off slot; a falsy value means "no answer ready yet".
    response = ""
    return response_
class Demo(object):
    """Run the demo web server together with a question-answering worker.

    Constructing a ``Demo`` starts ``demo_backend`` on a separate thread and
    then calls ``app.run``, so the constructor does not return until the
    server stops. When it stops, the worker is told to exit.

    Args:
        model: object exposing ``ask(question)``. Its return value is
            published to the ``/answer`` handler through the module-level
            ``response`` slot.
        config: unused; kept for interface compatibility.
        host: interface the HTTP server binds to (default ``'0.0.0.0'``).
        port: TCP port of the HTTP server (default ``9999``).
    """

    def __init__(self, model, config, host='0.0.0.0', port=9999):
        self.model = model
        # True while the worker loop should keep polling; cleared on shutdown.
        self.close_thread = True
        threading.Thread(target=self.demo_backend).start()
        try:
            # Blocks while serving. NOTE(review): bottle's run() is expected to
            # return on Ctrl+C; the finally clause stops the worker in either
            # case instead of waiting for a second interrupt.
            app.run(port=port, host=host)
        finally:
            print("Closing server...")
            self.close_thread = False

    def demo_backend(self):
        """Poll for pending questions and publish the model's answers.

        Loops (sleeping 0.1 s per iteration) while ``close_thread`` is true.
        Each pending ``query`` is taken and the slot cleared *before* the
        model runs, so a question posted while the previous answer is being
        delivered is not erased. If ``model.ask`` raises, the traceback is
        printed and an error message is published instead. The waiting
        request still gets a reply and the worker thread stays alive.
        """
        import traceback

        global query, response
        while self.close_thread:
            sleep(0.1)
            if query:
                pending = query
                query = []
                try:
                    # NOTE(review): a falsy return value leaves /answer waiting.
                    response = self.model.ask(pending)
                except Exception:
                    traceback.print_exc()
                    response = "Sorry, an error occurred while answering the question."
import argparse


def _build_parser():
    """Create the command-line parser; every option is a required path."""
    cli = argparse.ArgumentParser()
    required_paths = (
        ('-c', '--config', 'Path to bert_config.json'),
        ('-v', '--vocab', 'Path to vocab.txt'),
        ('-o', '--output', 'Directory of model outputs'),
        ('-r', '--ret-path', 'Retriever Path'),
    )
    for short_flag, long_flag, help_text in required_paths:
        cli.add_argument(short_flag, long_flag, help=help_text, required=True)
    return cli


parser = _build_parser()
def main():
    """Build the retriever/reader pipeline and start the demo server.

    Parses the command-line options defined on ``parser`` and unpickles
    the base retriever from ``--ret-path``, wrapping it in
    ``HierarchicalTfidf``. It then builds the BERT reader from the config,
    vocab and output paths, combines both in ``SOQAL`` and hands the result
    to ``Demo``.
    """
    args = parser.parse_args()
    # SECURITY: pickle.load can execute arbitrary code; only load retriever
    # files that you created or otherwise trust.
    with open(args.ret_path, "rb") as ret_file:
        base_r = pickle.load(ret_file)
    ret = HierarchicalTfidf(base_r, 50, 50)
    red = BERT_model(args.config, args.vocab, args.output)
    AI = SOQAL(ret, red, 0.999)
    # The Demo constructor runs the web server and returns only once it stops.
    Demo(AI, None)


if __name__ == "__main__":
    main()