forked from piggy2008/ImageEnhance
-
Notifications
You must be signed in to change notification settings - Fork 0
/
H5FileDataLoader.py
43 lines (32 loc) · 1.38 KB
/
H5FileDataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch.utils.data as data
import torch
import h5py
import numpy as np
class DatasetFromHdf5(data.Dataset):
    """Paired (input, target) dataset read from an HDF5 file, channel 0 only.

    The file must contain two datasets named ``data`` (inputs) and ``label``
    (targets). Both are indexed as 4-D arrays with the sample index on the
    first axis (presumably laid out as (N, C, H, W) -- TODO confirm against
    the script that writes the file). Only channel 0 of each sample is
    returned, with a size-1 channel axis restored.
    """

    def __init__(self, file_path):
        """Open *file_path* read-only and bind its ``data``/``label`` datasets.

        Raises:
            KeyError: if the file lacks a ``data`` or ``label`` dataset.
        """
        super(DatasetFromHdf5, self).__init__()
        # Open read-only explicitly. h5py < 3.0 defaulted to mode 'a', which
        # opens read/write and creates an empty file for a mistyped path.
        hf = h5py.File(file_path, 'r')
        # Index with [] rather than .get() so a missing dataset fails here
        # with a KeyError instead of surfacing later as None.
        self.data = hf['data']
        self.target = hf['label']

    def __getitem__(self, index):
        """Return ``(input, target)`` as float32 tensors of shape (1, H, W)."""
        # Take channel 0 of sample `index`, then re-add a leading channel axis.
        inp = np.expand_dims(self.data[index, 0, :, :], axis=0)
        tgt = np.expand_dims(self.target[index, 0, :, :], axis=0)
        return torch.from_numpy(inp).float(), torch.from_numpy(tgt).float()

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.data.shape[0]
class DatasetFromHdf5_clone(data.Dataset):
    """Paired (input, target) dataset read from an HDF5 file, all channels.

    Like ``DatasetFromHdf5`` but returns every channel of each sample. The
    file must contain two datasets named ``data`` (inputs) and ``label``
    (targets). Both are indexed as 4-D arrays with the sample index on the
    first axis (presumably (N, C, H, W) -- TODO confirm).
    """

    def __init__(self, file_path):
        """Open *file_path* read-only and bind its ``data``/``label`` datasets.

        Raises:
            KeyError: if the file lacks a ``data`` or ``label`` dataset.
        """
        # Bug fix: this previously called super(DatasetFromHdf5, self), which
        # names a sibling class and always raised TypeError.
        super(DatasetFromHdf5_clone, self).__init__()
        # Open read-only explicitly. h5py < 3.0 defaulted to mode 'a', which
        # opens read/write and creates an empty file for a mistyped path.
        hf = h5py.File(file_path, 'r')
        # Index with [] rather than .get() so a missing dataset fails here
        # with a KeyError instead of surfacing later as None.
        self.data = hf['data']
        self.target = hf['label']

    def __getitem__(self, index):
        """Return ``(input, target)`` as float32 tensors of shape (C, H, W)."""
        inp = self.data[index, :, :, :]
        tgt = self.target[index, :, :, :]
        return torch.from_numpy(inp).float(), torch.from_numpy(tgt).float()

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.data.shape[0]
if __name__ == '__main__':
    # Smoke test: iterate over an HDF5 dataset and print each batch index.
    import sys

    # An optional first command-line argument overrides the default path.
    h5_path = (sys.argv[1] if len(sys.argv) > 1
               else '/home/ty/code/pytorch-edsr/data/edsr_x4.h5')
    dataset = DatasetFromHdf5(h5_path)
    # NOTE(review): the h5py file is opened in the parent process before
    # worker processes start; confirm this is safe with num_workers > 0.
    dataLoader = torch.utils.data.DataLoader(dataset,
                                             batch_size=32,
                                             shuffle=True,
                                             num_workers=1)
    # The loop variables are not named `data`, so they do not shadow the
    # module-level alias for torch.utils.data.
    for i, (inputs, targets) in enumerate(dataLoader):
        print(i)