-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy paths6_dataset_0_convert_generative.py
85 lines (68 loc) · 3.69 KB
/
s6_dataset_0_convert_generative.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from os.path import join
import numpy as np
import zipstream
from api.ceb import CEBApi
from api.ldc import LdcAPI
from core.dataset.pairs_iterator import common_iter_dialogs
from core.dataset.pairs_with_candidates import provide_formatted_pairs
from core.spectrums.io_utils import SpectrumIOUtils
from core.utils_math import random_choice_non_repetitive
from e_pairs.cfg_hla import HlaExperimentConfig
from e_pairs.cfg_spectrum import SpectrumConfig
from utils import DATA_DIR
if __name__ == '__main__':
    # Convert LDC dialog folds into generative-format ParlAI dataset archives,
    # one zip-streamed .txt per (fold, trait-mode, candidates-mode) combination.
    ldc_api = LdcAPI()

    # Map every fixed fold name (e.g. "train") onto its source dataset filepath.
    dataset_filepaths = {part_name: ldc_api.dataset_fold_filepath.format(fold_index=part_name)
                         for part_name in LdcAPI.dataset_folding_fixed_parts}

    # NOTE(review): ceb_api is initialized but never referenced below — presumably
    # kept for the character-map loading side effect; confirm before removing.
    ceb_api = CEBApi(books_root=join(DATA_DIR, "books"), char_map_path=join(DATA_DIR, "chr_map.json"))
    ceb_api.read_char_map()

    hla_cfg = HlaExperimentConfig(books_storage=LdcAPI.books_storage)
    # Per-speaker HLA spectrum data ("prompts" and "weights"), keyed by speaker id.
    speaker_spectrums = SpectrumIOUtils.read(hla_cfg.hla_prompts_filepath)
    spectrum_cfg = SpectrumConfig()

    TRAITS_NO = "original"
    TRAITS_SPECTRUM = "spectrum"

    # Trait providers: TRAITS_NO yields a list of None placeholders; TRAITS_SPECTRUM
    # draws a weighted, non-repetitive sample of the partner's spectrum prompts,
    # falling back to TRAITS_NO when the partner has no recorded spectrum.
    traits_provider = {
        TRAITS_NO: lambda your_id, partner_id: [None] * spectrum_cfg.spectrum_per_user_count,
        TRAITS_SPECTRUM: lambda your_id, partner_id:
            random_choice_non_repetitive(v=speaker_spectrums[partner_id]["prompts"],
                                         size=spectrum_cfg.spectrum_per_user_count,
                                         p=np.absolute(speaker_spectrums[partner_id]["weights"]),
                                         to_list=True, take_less=True)
            if partner_id in speaker_spectrums else traits_provider[TRAITS_NO](your_id, partner_id)
    }

    CANDIDATES_UNIFORM = ""
    CANDIDATES_HLA_CLUSTER = "clustered"
    # NOTE(review): only the "no-cand" provider is registered, so the
    # CANDIDATES_UNIFORM / CANDIDATES_HLA_CLUSTER guards in the loop below are
    # currently dead; they are kept for when those providers are re-enabled.
    candidates_provider = {
        "no-cand": lambda _: None,
    }

    for data_fold_type, data_fold_source in dataset_filepaths.items():
        for trait_type, traits_func in traits_provider.items():
            for candidates_type, candidate_dict_func in candidates_provider.items():

                if trait_type == TRAITS_NO and candidates_type == CANDIDATES_HLA_CLUSTER:
                    # This combination does not make sense, so we skip such formatting.
                    continue
                if trait_type == TRAITS_SPECTRUM and candidates_type == CANDIDATES_UNIFORM and data_fold_type == "train":
                    continue
                if candidates_type == CANDIDATES_HLA_CLUSTER and data_fold_type != "train":
                    # We consider HLA clustering and candidates selection only for training.
                    continue

                # Output filename parts: fold, trait mode, and (optionally) candidates mode.
                args = [data_fold_type, trait_type]
                if candidates_type != "":
                    args.append(candidates_type)

                # There is no need to perform oversampling for non-train dataset type.
                oversample_factor = None if data_fold_type != "train" else \
                    ldc_api.parlai_dataset_train_candidates_oversample_factor

                data_it = provide_formatted_pairs(
                    dialogs_iter=common_iter_dialogs(data_fold_source),
                    traits_func=traits_func,
                    candidates_provider=candidate_dict_func(data_fold_type),
                    candidates_oversample_factor=oversample_factor)

                # Stream the formatted episodes into a fresh zip archive on disk.
                # (The redundant ZipFile created before the loop in the original
                # was never used and has been removed.)
                z = zipstream.ZipFile()
                filename = '{}.txt'.format("_".join(args))
                z.write_iter(filename, data_it)

                target = ldc_api.parlai_dataset_filepath.format(filename)
                with open(target, "wb") as f:
                    for episode_line in z:
                        f.write(episode_line)

                print("Saved: {}".format(target))