-
Notifications
You must be signed in to change notification settings - Fork 46
/
split_data.py
96 lines (72 loc) · 2.96 KB
/
split_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# -*- coding: utf-8 -*-
import copy
import random
import fire
import numpy as np
from arena_util import load_json
from arena_util import write_json
class ArenaSplitter:
    """Split a playlist dump into train/validation sets and mask the validation set.

    ``run`` is the public entry point; the underscore-prefixed methods are
    internal steps. Playlists are treated as dicts that carry list-valued
    ``'songs'`` and ``'tags'`` entries.
    """

    def _seed_rngs(self, seed):
        """Seed both ``random`` and ``numpy.random`` with *seed*.

        ``run`` shuffles the playlists with ``random``, while the masking
        steps draw from ``numpy.random``; seeding only one of the two leaves
        the masked output different on every run.
        """
        random.seed(seed)
        np.random.seed(seed)

    def _split_data(self, playlists):
        """Return ``(train, val)``: the first 80% of *playlists* and the rest."""
        cut = int(len(playlists) * 0.80)
        return playlists[:cut], playlists[cut:]

    def _mask(self, playlists, mask_cols, del_cols):
        """Build question/answer copies of *playlists*.

        Args:
            playlists: list of playlist dicts; it is not modified.
            mask_cols: keys whose list is randomly split in two. ``len // 2``
                items go to the question, the remaining items to the answer.
            del_cols: keys whose list is emptied in the question. In the
                answer, ``'songs'`` is cut to its first 100 items and
                ``'tags'`` to its first 10.

        Returns:
            ``(questions, answers)``, two lists of deep-copied playlists in
            the same order as *playlists*.
        """
        q_pl = copy.deepcopy(playlists)
        a_pl = copy.deepcopy(playlists)

        for original, question, answer in zip(playlists, q_pl, a_pl):
            for del_col in del_cols:
                question[del_col] = []
                if del_col == 'songs':
                    answer[del_col] = answer[del_col][:100]
                elif del_col == 'tags':
                    answer[del_col] = answer[del_col][:10]

            for col in mask_cols:
                mask_len = len(original[col])
                mask = np.full(mask_len, False)
                mask[:mask_len // 2] = True
                np.random.shuffle(mask)
                # Filter the Python lists directly instead of round-tripping
                # through np.array: that keeps native int/str items rather
                # than NumPy scalars, which the standard json encoder cannot
                # serialize without a custom default handler.
                question[col] = [
                    item for item, keep in zip(question[col], mask) if keep
                ]
                answer[col] = [
                    item for item, keep in zip(answer[col], mask) if not keep
                ]

        return q_pl, a_pl

    def _mask_data(self, playlists):
        """Mask *playlists* into shuffled question/answer lists.

        The input is cut by position into four groups, each masked
        differently:

        * first 30%: songs half-masked, tags removed
        * next 50%: songs and tags both half-masked
        * next 15%: tags half-masked, songs removed
        * last 5%: songs and tags both removed

        Returns:
            ``(q, a)``: questions and answers, shuffled with one shared
            permutation so that ``q[i]`` and ``a[i]`` stay paired.
        """
        # No defensive copy here: _mask deep-copies whatever it receives and
        # never writes to its input.
        tot = len(playlists)
        song_end = int(tot * 0.3)
        songtag_end = int(tot * 0.8)
        tag_end = int(tot * 0.95)

        song_only = playlists[:song_end]
        song_and_tags = playlists[song_end:songtag_end]
        tags_only = playlists[songtag_end:tag_end]
        title_only = playlists[tag_end:]

        print(f"Total: {len(playlists)}, "
              f"Song only: {len(song_only)}, "
              f"Song & Tags: {len(song_and_tags)}, "
              f"Tags only: {len(tags_only)}, "
              f"Title only: {len(title_only)}")

        song_q, song_a = self._mask(song_only, ['songs'], ['tags'])
        songtag_q, songtag_a = self._mask(song_and_tags, ['songs', 'tags'], [])
        tag_q, tag_a = self._mask(tags_only, ['tags'], ['songs'])
        title_q, title_a = self._mask(title_only, [], ['songs', 'tags'])

        q = song_q + songtag_q + tag_q + title_q
        a = song_a + songtag_a + tag_a + title_a

        # One permutation applied to both lists keeps each pair aligned.
        shuffle_indices = np.arange(len(q))
        np.random.shuffle(shuffle_indices)
        q = [q[idx] for idx in shuffle_indices]
        a = [a[idx] for idx in shuffle_indices]

        return q, a

    def run(self, fname, seed=777):
        """Load playlists from *fname*, split them, and write all outputs.

        Args:
            fname: path handed to ``load_json``.
            seed: seed for both random generators. The default of 777 is the
                value the shuffle has always used, so the train/val split is
                unchanged.

        Writes ``orig/train.json``, ``orig/val.json``, ``questions/val.json``
        and ``answers/val.json`` through ``write_json``.
        NOTE(review): how ``write_json`` resolves these relative paths is
        defined in ``arena_util`` and is not visible here.
        """
        self._seed_rngs(seed)

        print("Reading data...\n")
        playlists = load_json(fname)
        random.shuffle(playlists)
        print(f"Total playlists: {len(playlists)}")

        print("Splitting data...")
        train, val = self._split_data(playlists)

        print("Original train...")
        write_json(train, "orig/train.json")
        print("Original val...")
        write_json(val, "orig/val.json")

        print("Masked val...")
        val_q, val_a = self._mask_data(val)
        write_json(val_q, "questions/val.json")
        write_json(val_a, "answers/val.json")
if __name__ == "__main__":
    # Hand the class to Fire so its public methods become CLI commands.
    fire.Fire(component=ArenaSplitter)