-
Notifications
You must be signed in to change notification settings - Fork 7
/
Dataset.py
74 lines (64 loc) · 2.32 KB
/
Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function, division
import numpy as np
from Cifar import Cifar
from NUS_21 import NUS_21
from Imagenet import Imagenet
class Dataset(object):
    """Mini-batch iterator over one of the supported dataset loaders.

    The concrete loader (``Cifar``, ``Imagenet`` or ``NUS_21``) is selected
    by name and must expose ``SamplesCount`` and ``Get(indices)``. In
    ``"train"`` mode the visiting order is a random permutation that is
    redrawn every epoch; in any other mode samples are visited in index
    order.
    """

    def __init__(self, dataset, mode, batchSize, W, H):
        """Create the underlying loader and the initial visiting order.

        Args:
            dataset: Dataset name, case-insensitive ("cifar", "imagenet",
                "nus").
            mode: Mode/split string, case-insensitive. It is lower-cased,
                stored on ``self.mode`` and forwarded to the loader;
                "train" enables per-epoch shuffling.
            batchSize: Number of order entries consumed per ``NextBatch``.
            W, H: Forwarded unchanged to the loader constructor
                (presumably target image width/height -- confirm).

        Raises:
            NameError: If ``dataset`` is not a supported name.
        """
        print(dataset)
        # Lower-case before storing so that e.g. "Train" also enables the
        # shuffling checks that compare self.mode against "train".
        mode = mode.lower()
        self.mode = mode
        self.data = Dataset._LoaderClass(dataset.upper())(mode, W, H)
        self._current = 0
        self._batchSize = batchSize
        self._index = None  # indices of the most recent NextBatch
        self.choice = self._Order()

    @staticmethod
    def _LoaderClass(name):
        """Map an upper-case dataset name to its loader class.

        Raises:
            NameError: If ``name`` is not "CIFAR", "IMAGENET" or "NUS".
        """
        if name == "CIFAR":
            return Cifar
        if name == "IMAGENET":
            return Imagenet
        if name == "NUS":
            return NUS_21
        raise NameError("No dataset named {0}".format(name))

    def _Order(self):
        """Return a fresh visiting order: shuffled for "train", else sequential."""
        if self.mode == "train":
            return np.random.permutation(self.data.SamplesCount)
        return np.arange(0, self.data.SamplesCount, 1)

    def NextBatch(self):
        """Return the next mini-batch and advance the cursor.

        Slices the next ``batchSize`` entries of ``self.choice`` (fewer if
        the slice runs past the end), records them in ``self._index`` and
        returns ``self.data.Get`` of those indices.
        """
        idx = self.choice[self._current:self._current + self._batchSize]
        self._index = idx
        self._current += self._batchSize
        return self.data.Get(idx)

    def Index(self, i, batchSize):
        """Return the ``i``-th batch of ``batchSize`` samples in raw index order.

        Independent of ``self.choice`` and the ``NextBatch`` cursor. The
        index range is clipped to ``SamplesCount``, so the final batch may
        be shorter and an ``i`` past the end yields an empty index array,
        rather than requesting out-of-range samples from the loader.
        """
        start = i * batchSize
        stop = min((i + 1) * batchSize, self.data.SamplesCount)
        return self.data.Get(np.arange(start, stop, 1))

    @property
    def EpochComplete(self):
        """True once another full batch no longer fits in the current epoch.

        NOTE: reading this property has side effects. When it is True, the
        cursor is reset to 0 and, in "train" mode, the order is reshuffled.
        Because completion is declared as soon as a *full* batch no longer
        fits, a loop like ``while not ds.EpochComplete: ds.NextBatch()``
        skips the last ``SamplesCount % batchSize`` samples of each epoch.
        """
        complete = self._current + self._batchSize > self.data.SamplesCount
        if complete:
            self._current = 0
            if self.mode == "train":
                self.choice = self._Order()
        return complete

    @property
    def Progress(self):
        """Cursor position as a fraction of ``SamplesCount`` (true division)."""
        return self._current / self.data.SamplesCount

    @staticmethod
    def PreparetoEval(setName, W, H):
        """Load the "database" and "query" splits of ``setName`` for evaluation.

        Args:
            setName: Dataset name, case-insensitive ("cifar", "imagenet",
                "nus").
            W, H: Forwarded unchanged to both loader constructors.

        Returns:
            Tuple ``(queryX, queryY, database)``: the query loader's
            ``GetX()`` result, its ``Y`` attribute, and the database loader.

        Raises:
            NameError: If ``setName`` is not a supported name.
        """
        setName = setName.upper()
        print(setName)
        loader = Dataset._LoaderClass(setName)
        database = loader('database', W, H)
        query = loader('query', W, H)
        queryX = query.GetX()
        queryY = query.Y
        return queryX, queryY, database