# -*- coding: utf-8 -*-
"""
Created on 18 Sep, 2019

@author: wangshuo

Reference: https://github.com/lijingsdu/sessionRec_NARM/blob/master/data_process.py
"""
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset


def load_data(root, valid_portion=0.1, maxlen=19, sort_by_len=False):
    '''Loads the dataset.

    :type root: str
    :param root: The path to the dataset directory (here RSC2015).
    :type valid_portion: float
    :param valid_portion: The proportion of the full train set used for
        the validation set.
    :type maxlen: None or positive int
    :param maxlen: The max sequence length we use in the train/valid/test
        sets.
    :type sort_by_len: bool
    :param sort_by_len: Sort by sequence length for the valid and test
        sets. This allows faster execution as it causes less padding per
        minibatch. Another mechanism must be used to shuffle the train
        set at each epoch.
    '''
    # Load the pickled (sessions, targets) pairs.
    path_train_data = root + 'train.txt'
    path_test_data = root + 'test.txt'
    with open(path_train_data, 'rb') as f1:
        train_set = pickle.load(f1)
    with open(path_test_data, 'rb') as f2:
        test_set = pickle.load(f2)

    if maxlen:
        # Truncate sessions longer than maxlen to their first maxlen items.
        new_train_set_x = []
        new_train_set_y = []
        for x, y in zip(train_set[0], train_set[1]):
            if len(x) < maxlen:
                new_train_set_x.append(x)
                new_train_set_y.append(y)
            else:
                new_train_set_x.append(x[:maxlen])
                new_train_set_y.append(y)
        train_set = (new_train_set_x, new_train_set_y)
        del new_train_set_x, new_train_set_y

        new_test_set_x = []
        new_test_set_y = []
        for xx, yy in zip(test_set[0], test_set[1]):
            if len(xx) < maxlen:
                new_test_set_x.append(xx)
                new_test_set_y.append(yy)
            else:
                new_test_set_x.append(xx[:maxlen])
                new_test_set_y.append(yy)
        test_set = (new_test_set_x, new_test_set_y)
        del new_test_set_x, new_test_set_y

    # Split the training set into train and validation sets.
    train_set_x, train_set_y = train_set
    n_samples = len(train_set_x)
    sidx = np.arange(n_samples, dtype='int32')
    np.random.shuffle(sidx)
    n_train = int(np.round(n_samples * (1. - valid_portion)))
    valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
    valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
    train_set_x = [train_set_x[s] for s in sidx[:n_train]]
    train_set_y = [train_set_y[s] for s in sidx[:n_train]]

    (test_set_x, test_set_y) = test_set

    def len_argsort(seq):
        # Indices that would sort `seq` by element length, ascending.
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    if sort_by_len:
        sorted_index = len_argsort(test_set_x)
        test_set_x = [test_set_x[i] for i in sorted_index]
        test_set_y = [test_set_y[i] for i in sorted_index]

        sorted_index = len_argsort(valid_set_x)
        valid_set_x = [valid_set_x[i] for i in sorted_index]
        valid_set_y = [valid_set_y[i] for i in sorted_index]

    train = (train_set_x, train_set_y)
    valid = (valid_set_x, valid_set_y)
    test = (test_set_x, test_set_y)

    return train, valid, test
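

# Variable-length sessions cannot be stacked by DataLoader's default
# collate function, so a padding collate is needed when batching
# RecSysDataset (defined below). The sketch here is an illustrative
# assumption, not part of the original file: it assumes item id 0 is
# reserved for padding and sorts each batch by length, descending,
# which suits pack_padded_sequence-style models.
def collate_fn(batch):
    """Pad a batch of (session, target) pairs to a common length.

    Returns padded sessions of shape (max_len, batch_size), the true
    session lengths, and the targets, all as LongTensors.
    """
    # Longest sessions first, so lengths are monotonically decreasing.
    batch.sort(key=lambda pair: len(pair[0]), reverse=True)
    sessions, targets = zip(*batch)
    lens = [len(sess) for sess in sessions]
    padded = torch.zeros(max(lens), len(sessions)).long()
    for i, sess in enumerate(sessions):
        padded[:lens[i], i] = torch.LongTensor(sess)
    return padded, torch.tensor(lens).long(), torch.tensor(targets).long()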


class RecSysDataset(Dataset):
    """PyTorch Dataset for session-based recommendation datasets such as
    yoochoose and diginetica: index i yields (session items, target item).
    """

    def __init__(self, data):
        self.data = data
        print('-' * 50)
        print('Dataset info:')
        print('Number of sessions: {}'.format(len(data[0])))
        print('-' * 50)

    def __getitem__(self, index):
        session_items = self.data[0][index]
        target_item = self.data[1][index]
        return session_items, target_item

    def __len__(self):
        return len(self.data[0])
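

if __name__ == '__main__':
    # A minimal smoke test, assuming `root` points at a directory holding
    # the pickled train.txt / test.txt produced by the preprocessing
    # script referenced in the module docstring; the path below is
    # illustrative, not a real location in this repo.
    from torch.utils.data import DataLoader

    root = 'datasets/diginetica/'  # hypothetical dataset directory
    train, valid, test = load_data(root, valid_portion=0.1)
    train_loader = DataLoader(RecSysDataset(train), batch_size=128,
                              shuffle=True, collate_fn=collate_fn)
    for sessions, lens, targets in train_loader:
        print(sessions.shape, lens.shape, targets.shape)
        break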