import sys
sys.path.append("./")
import argparse
import json
import random
random.seed(42)
from convert.utils.instruction import instruction_mapper
from convert.utils.utils import stable_hash, write_to_json
from convert.processer import get_processer
from convert.converter import get_converter
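# Summary (added for readability, paraphrasing the usage notes at the bottom of this file):
# this script converts raw IE samples into instruction-format train/test data for tasks such as
# NER, RE, EE, EET, EEA, SPO and KG, splitting schemas into chunks of --split_num per instruction.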
def multischema_split_by_num_test(schemas, split_num=4):
    """Split the schema list into chunks of at most split_num entries.

    A trailing remainder shorter than max(1, split_num // 2) is merged into
    the last chunk instead of forming its own chunk.
    """
    if len(schemas) < split_num or split_num == -1:
        return [schemas, ]
    negative_length = max(len(schemas) // split_num, 1) * split_num
    total_schemas = []
    for i in range(0, negative_length, split_num):
        total_schemas.append(schemas[i:i+split_num])
    remain_len = max(1, split_num // 2)
    if len(schemas) - negative_length >= remain_len:
        total_schemas.append(schemas[negative_length:])
    else:
        total_schemas[-1].extend(schemas[negative_length:])
    return total_schemas
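# Illustrative walk-through of the split (added example, not part of the original script):
# with 10 schemas and split_num=4, negative_length is 8, giving chunks of sizes [4, 4, 2]
# because the remainder of 2 reaches remain_len; with 9 schemas the remainder of 1 is
# merged into the last chunk, giving sizes [4, 5].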
def multischema_construct_instruction(task, language, schema1, text):
    """Build a single instruction record and serialize it as a JSON string."""
    instruction = {
        "instruction": instruction_mapper[task+language],
        "schema": schema1,
        "input": text,
    }
    return json.dumps(instruction, ensure_ascii=False)
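# Sketch of the serialized instruction (schema entries and input are hypothetical placeholders):
#   {"instruction": <template from instruction_mapper for task+language>,
#    "schema": ["person", "location"], "input": "..."}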
def get_test_data(datas, processer, options):
    """Convert raw records into test instructions; records without schemas are skipped."""
    results = []
    for record in datas:
        iid = stable_hash(record['text'])
        task_record = processer.get_task_record(record)
        schemas = processer.get_schemas(task_record)
        if schemas is None:
            continue
        total_schemas = multischema_split_by_num_test(schemas, options.split_num)
        for schema in total_schemas:
            sinstruct = multischema_construct_instruction(options.task, options.language, schema, record['text'])
            record2 = {
                'id': iid,
                'task': options.task,
                'source': options.source,
                'instruction': sinstruct,
            }
            if task_record is not None:
                record2['label'] = json.dumps(task_record, ensure_ascii=False)
            results.append(record2)
    return results
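# Sketch of one resulting test record (field values are hypothetical):
#   {"id": <stable_hash of the text>, "task": "NER", "source": "NER",
#    "instruction": <JSON string built above>, "label": <JSON-serialized task_record, if present>}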
def convert_output(converter, text, schemas, task_record):
    """Produce one output string per schema chunk using the task-specific converter."""
    output_texts = []
    if len(schemas) == 0:
        return output_texts
    label_dict = converter.get_label_dict(task_record)
    for schema in schemas:
        output_text = converter.convert(
            text, label_dict, s_schema1=schema
        )
        output_texts.append(output_text)
    return output_texts
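# Note (added): output_texts is aligned one-to-one with the schema chunks, so get_train_data
# below can zip them to pair each instruction with its gold output.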
def get_train_data(datas, processer, converter, options):
    """Convert raw records into training instructions paired with gold outputs."""
    results = []
    for record in datas:
        if options.cluster_mode:
            total_schemas = processer.negative_cluster_sample(record, options.split_num, options.random_sort)
        else:
            total_schemas = processer.negative_sample(record, options.split_num, options.random_sort)
        task_record = processer.get_task_record(record)
        # Schemas and output texts are chunked according to split_num.
        output_texts = convert_output(converter, record['text'], total_schemas, task_record)
        for schema, output_text in zip(total_schemas, output_texts):
            sinstruct = multischema_construct_instruction(options.task, options.language, schema, record['text'])
            record2 = {
                'task': options.task,
                'source': options.source,
                'instruction': sinstruct,
                'output': output_text
            }
            results.append(record2)
    return results
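# Sketch of one resulting training record (field values are hypothetical):
#   {"task": "NER", "source": "NER", "instruction": <JSON string>,
#    "output": <converter output for that schema chunk>}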
def process(options):
    converter = get_converter(options.task)(options.language, NAN='NAN')
    processer_class = get_processer(options.task)
    processer = processer_class.read_from_file(
        processer_class, options.schema_path, negative=-1
    )
    if options.cluster_mode:
        processer.set_hard_dict(json.load(open(options.hard_negative_path, 'r')))
    processer.set_negative(options.neg_schema)
    # Use the last directory name in the source path as the data source tag.
    options.source = options.src_path.split('/')[-2]
    datas = processer.read_data(options.src_path)
    if options.split == 'test':
        results = get_test_data(datas, processer, options)
    else:
        results = get_train_data(datas, processer, converter, options)
    write_to_json(options.tgt_path, results)
'''
Generate test data:
python ie2instruction/convert_func.py \
    --src_path data/NER/sample.json \
    --tgt_path data/NER/test.json \
    --schema_path data/NER/schema.json \
    --language zh \
    --task NER \
    --split_num 6 \
    --split test

Generate training data:
python ie2instruction/convert_func.py \
    --src_path data/NER/sample.json \
    --tgt_path data/NER/train.json \
    --schema_path data/NER/schema.json \
    --language zh \
    --task NER \
    --split_num 6 \
    --random_sort \
    --split train

Generate training data with hard negative samples:
python ie2instruction/convert_func.py \
    --src_path data/SPO/sample.json \
    --tgt_path data/SPO/train.json \
    --schema_path data/SPO/schema.json \
    --cluster_mode \
    --hard_negative_path data/hard_negative/SPO_DuIE2.0.json \
    --language zh \
    --task SPO \
    --split_num 4 \
    --random_sort \
    --split train
'''
if __name__ == "__main__":
    parse = argparse.ArgumentParser()
    parse.add_argument("--src_path", type=str, default="data/NER/sample.json")
    parse.add_argument("--tgt_path", type=str, default="data/NER/processed.json")
    parse.add_argument("--schema_path", type=str, default='data/NER/schema.json')
    parse.add_argument("--hard_negative_path", type=str, default=None)
    parse.add_argument("--cluster_mode", action='store_true', help="Whether to use cluster mode; negative samples then consist of the hard negatives plus split_num other negatives.")
    parse.add_argument("--language", type=str, default='zh', choices=['zh', 'en'], help="Different languages use different templates and conversion scripts.")
    parse.add_argument("--task", type=str, default="NER", choices=['RE', 'NER', 'EE', 'EET', 'EEA', 'SPO', 'KG'])
    parse.add_argument("--split", type=str, default='train', choices=['train', 'test'])
    parse.add_argument("--split_num", type=int, default=4, help="Maximum number of schemas in a single instruction. Default is 4; -1 means no splitting. Recommended values differ by task: NER:6, RE:4, EE:4, EET:4, EEA:4, KG:1.")
    parse.add_argument("--neg_schema", type=float, default=1, help="Proportion of negative samples in the instruction; default is 1, i.e. all negative samples are used.")
    parse.add_argument("--random_sort", action='store_true', help="Whether to randomly sort the schemas in the instruction; default is False, i.e. alphabetical order.")
    options = parse.parse_args()
    process(options)