# build_index.py
# Reads CoNLL data from stdin and builds a series of sqlite databases that
# index the sentences by token, lemma, tag and dependency relation.
import argparse
import sys
import cPickle as pickle
import sqlite3
import codecs
from datetime import datetime
from tree import Tree, SymbolStats
import json
import re
import struct
import os
import setlib.pytset as pytset
import zlib
import itertools
import os.path
import traceback

# CoNLL-09 style column indices
ID,FORM,LEMMA,PLEMMA,POS,PPOS,FEAT,PFEAT,HEAD,PHEAD,DEPREL,PDEPREL=range(12)
# Anything that is not an ASCII letter, digit or underscore
symbs=re.compile(ur"[^A-Za-z0-9_]",re.U)
def prepare_tables(conn):
    build=\
    """
    CREATE TABLE graph (
        graph_id INTEGER,
        token_count INTEGER,
        conllu_data_compressed BLOB,
        conllu_comment_compressed BLOB
    );
    CREATE TABLE token_index (
        norm INTEGER,
        token TEXT,
        graph_id INTEGER,
        token_set BLOB
    );
    CREATE TABLE lemma_index (
        norm INTEGER,
        lemma TEXT,
        graph_id INTEGER,
        token_set BLOB
    );
    CREATE TABLE tag_index (
        graph_id INTEGER,
        tag TEXT,
        token_set BLOB
    );
    CREATE TABLE rel (
        graph_id INTEGER,
        dtype TEXT,
        token_gov_map BLOB,
        token_dep_map BLOB
    );
    """
    for q in build.split(";"):
        q=q.strip()
        if q:
            print q
            conn.execute(q)
    conn.commit()
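
# A stored sentence can be read back with a plain decompress; a minimal sketch
# (assumes a populated connection `conn` and that graph_id 0 exists):
#
#   row=conn.execute('SELECT conllu_data_compressed FROM graph WHERE graph_id=?',(0,)).fetchone()
#   conllu=zlib.decompress(str(row[0])).decode("utf-8")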
def build_indices(conn):
    build=\
    """
    CREATE UNIQUE INDEX tok_gid ON token_index(token,graph_id,norm);
    CREATE UNIQUE INDEX lemma_gid ON lemma_index(lemma,graph_id,norm);
    CREATE UNIQUE INDEX gid_tag ON tag_index(graph_id,tag);
    CREATE INDEX tag_gid ON tag_index(tag,graph_id);
    CREATE UNIQUE INDEX gid ON graph(graph_id);
    CREATE UNIQUE INDEX gid_dtype ON rel(graph_id,dtype);
    analyze;
    """
    for q in build.split(";"):
        if q.strip():
            print q
            conn.execute(q)
    conn.commit()
def read_conll(inp,maxsent=0):
    """ Read conll format file and yield one sentence at a time as a list of lists of columns. If inp is a string it will be interpreted as a filename, otherwise as an open file for reading in unicode"""
    if isinstance(inp,basestring):
        f=codecs.open(inp,u"rt",u"utf-8")
    else:
        f=codecs.getreader("utf-8")(inp) # read inp directly
    count=0
    sent=[]
    comments=[]
    for line in f:
        line=line.strip()
        if not line:
            if sent:
                count+=1
                yield sent, comments
                if maxsent!=0 and count>=maxsent:
                    break
            sent=[]
            comments=[]
        elif line.startswith(u"#"):
            if sent:
                raise ValueError("Missing newline after sentence")
            comments.append(line)
            continue
        else:
            cols=line.split(u"\t")
            if cols[0].isdigit() or u"." in cols[0]:
                sent.append(cols)
    else: # input exhausted without hitting maxsent
        if sent: # no trailing blank line; yield the final sentence
            yield sent, comments
    if isinstance(inp,basestring):
        f.close() #Close it if you opened it
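
# Usage sketch for read_conll (the file name is hypothetical):
#
#   for sent,comments in read_conll("corpus.conllu",maxsent=10):
#       print len(sent), len(comments)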
def add_doc_comments(sents):
    """
    in goes an iterator over sent,comments pairs
    out goes an iterator with comments enriched by URLs
    """
    urlRe=re.compile(ur'url="(.*?)"',re.U)
    doc_counter,sent_in_doc_counter=-1,0
    current_url=None
    for sent,comments in sents:
        ###PB3 style URLs
        if len(sent)==1 and sent[0][1].startswith("####FIPBANK-BEGIN-MARKER:"):
            current_url=sent[0][1].split(u":",1)[1]
            doc_counter+=1
            sent_in_doc_counter=0
            continue
        ###Todo: PB4 style URL comments
        for c in comments:
            if c.startswith(u"###C:</doc"):
                current_url=None
            elif c.startswith(u"###C:<doc"):
                match=urlRe.search(c)
                if not match: #WHoa!
                    print >> sys.stderr, "Missing url", c.encode("utf-8")
                else:
                    current_url=match.group(1)
        ###C:<doc id="3-1954112" length="1k-10k" crawl_date="2014-07-26" url="http://parolanasema.blogspot.fi/2013/02/paris-paris-maison-objet-messut-osa-1.html" langdiff="0.37">
        if current_url is not None:
            comments.append(u"# URL: "+current_url)
            comments.append(u"# DOC/SENT: %d/%d"%(doc_counter,sent_in_doc_counter))
        yield sent,comments
        sent_in_doc_counter+=1
def serialize_as_tset_array(tree_len,sets):
    """
    tree_len -> length of the tree to be serialized
    sets: array of tree_len sets, each set holding the indices of the elements
    """
    indices=[]
    for set_idx,s in enumerate(sets):
        for item in s:
            indices.append(struct.pack("@HH",set_idx,item))
    #print "IDXs", len(indices)
    res=struct.pack("@H",tree_len)+("".join(indices))
    return res
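
# For example, serialize_as_tset_array(3,[{1},set(),{0,2}]) packs the unsigned
# short 3 followed by one (set_idx,item) pair of unsigned shorts per set member,
# here (0,1),(2,0),(2,2), all in native byte order.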
rsub=re.compile(ur"[#-]",re.U)
def fill_db(conn,src_data,stats):
    """
    `src_data` - iterator over sentences - result of read_conll()
    """
    counter=0
    for sent_idx,(sent,comments) in enumerate(add_doc_comments(src_data)):
        if len(sent)>256:
            print >> sys.stderr, "skipping length", len(sent)
            sys.stderr.flush()
            continue
        counter+=1
        t=Tree.from_conll(comments,sent,stats)
        conn.execute('INSERT INTO graph VALUES(?,?,?,?)', [sent_idx,len(sent),buffer(zlib.compress(t.conllu.encode("utf-8"))),buffer(zlib.compress(t.comments.encode("utf-8")))])
        for token, token_set in t.tokens.iteritems():
            conn.execute('INSERT INTO token_index VALUES(?,?,?,?)', [0,token,sent_idx,buffer(token_set.tobytes())])
        for lemma, token_set in t.lemmas.iteritems():
            conn.execute('INSERT INTO lemma_index VALUES(?,?,?,?)', [0,lemma,sent_idx,buffer(token_set.tobytes())])
        for token, token_set in t.normtokens.iteritems():
            conn.execute('INSERT INTO token_index VALUES(?,?,?,?)', [1,token,sent_idx,buffer(token_set.tobytes())])
        for lemma, token_set in t.normlemmas.iteritems():
            conn.execute('INSERT INTO lemma_index VALUES(?,?,?,?)', [1,lemma,sent_idx,buffer(token_set.tobytes())])
        for tag, token_set in t.tags.iteritems():
            conn.execute('INSERT INTO tag_index VALUES(?,?,?)', [sent_idx,tag,buffer(token_set.tobytes())])
        for dtype, (govs,deps) in t.rels.iteritems():
            ne_g=[x for x in govs if x]
            ne_d=[x for x in deps if x]
            assert ne_g and ne_d
            gov_set=pytset.PyTSet(len(sent),(idx for idx,s in enumerate(govs) if s))
            dep_set=pytset.PyTSet(len(sent),(idx for idx,s in enumerate(deps) if s))
            try:
                conn.execute('INSERT INTO rel VALUES(?,?,?,?)', [sent_idx,dtype,buffer(serialize_as_tset_array(len(sent),govs)),buffer(serialize_as_tset_array(len(sent),deps))])
            except struct.error:
                for l in sent:
                    print >> sys.stderr, l
                print >> sys.stderr
        if sent_idx%10000==0:
            print str(datetime.now()), sent_idx
            sys.stdout.flush()
        if sent_idx%10000==0:
            conn.commit()
    conn.commit()
    return counter
def save_stats(stats):
    # Note: relies on the global `args` parsed in __main__
    try:
        if os.path.exists(os.path.join(args.dir,"symbols.json")):
            stats.update_with_json(os.path.join(args.dir,"symbols.json"))
    except:
        traceback.print_exc()
    stats.save_json(os.path.join(args.dir,"symbols.json"))
def skip(items,skip):
    counter=0
    for i in items:
        counter+=1
        if counter<=skip:
            if counter%1000000==0:
                print >> sys.stderr, "Skipped ", counter
            continue
        yield i
if __name__=="__main__":
    parser = argparse.ArgumentParser(description='Build a sqlite index of CoNLL data read from stdin')
    parser.add_argument('-d', '--dir', required=True, help='Directory name to save the index. Created if missing; wiped only with --wipe.')
    parser.add_argument('-p', '--prefix', default="trees", help='Prefix name of the database files. Default: %(default)s')
    parser.add_argument('--skip', type=int, default=0, help='How many sentences to skip from stdin before starting? default: %(default)d')
    parser.add_argument('--max', type=int, default=0, help='How many sentences to read from stdin? 0 for all. default: %(default)d')
    parser.add_argument('--wipe', default=False, action="store_true", help='Wipe the target directory before building the index.')
    args = parser.parse_args()
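
    # Example invocation (the corpus path is hypothetical; the CoNLL data is read from stdin):
    #
    #   cat corpus.conllu | python build_index.py -d /path/to/index_dir --wipe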
    # gather_tbl_names(codecs.getreader("utf-8")(sys.stdin))
    os.system("mkdir -p "+args.dir)
    if args.wipe:
        print >> sys.stderr, "Wiping target"
        cmd="rm -f %s/*.db %s/*.db-journal %s/symbols.json"%(args.dir, args.dir,args.dir)
        print >> sys.stderr, cmd
        os.system(cmd)
    src_data=skip(read_conll(sys.stdin,args.max),args.skip)
    batch=500000
    counter=0
    while True:
        stats=SymbolStats()
        db_name=args.dir+"/%s_%05d.db"%(args.prefix,counter)
        if os.path.exists(db_name):
            os.system("rm -f "+db_name)
        conn=sqlite3.connect(db_name)
        prepare_tables(conn)
        it=itertools.islice(src_data,batch)
        filled=fill_db(conn,it,stats)
        if filled==0:
            os.system("rm -f "+db_name)
            save_stats(stats)
            break
        build_indices(conn)
        conn.close()
        counter+=1
        save_stats(stats)