-
Notifications
You must be signed in to change notification settings - Fork 4
/
csdata_fast.py
101 lines (80 loc) · 3.19 KB
/
csdata_fast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import imageio
import os
import torch
import glob
import random
import numpy as np
from functools import lru_cache
# RGB -> YCbCr matrix (ITU-R BT.601 coefficients, same values as scikit-image's
# ``ycbcr_from_rgb``). Rows produce Y, Cb, Cr from RGB scaled to [0, 1]; with the
# offsets added in ``rgb2ycbcr``, Y spans [16, 235] and Cb/Cr span [16, 240].
ycbcr_from_rgb = torch.tensor([[ 65.481, 128.553, 24.966],
                               [-37.797, -74.203, 112.0],
                               [112.0, -93.786, -18.214]])


def rgb2ycbcr(rgb):
    """Convert an RGB tensor with values in [0, 255] to YCbCr.

    Args:
        rgb: tensor whose last dimension holds the (R, G, B) channels,
            e.g. shaped (H, W, 3). Any numeric dtype is accepted (it is cast
            with ``.float()``), on any device.

    Returns:
        Floating-point tensor of the same shape holding (Y, Cb, Cr). For inputs
        in [0, 255], Y lies in [16, 235] and Cb/Cr in [16, 240].
    """
    arr = rgb.float() / 255.0
    # Move the constants to the input's device so GPU tensors work as well.
    matrix = ycbcr_from_rgb.to(arr.device)
    offset = torch.tensor([16.0, 128.0, 128.0], device=arr.device)
    return arr @ matrix.transpose(1, 0) + offset
class SlowDataset(Dataset):
    """Random-patch dataset backed by a per-image ``.npy`` cache.

    Images matching ``args.ext`` in ``<data_dir>/<train_name>`` are decoded once
    with ``imageio`` and stored as ``.npy`` files in the sibling folder
    ``<data_dir>/<train_name>bin``; items are then loaded from those arrays.
    Each item is a random square crop of ``args.patch_size`` pixels, converted
    to ``args.n_channels`` channels, scaled by ``rgb_range / 255`` and randomly
    flipped/transposed, returned as a float tensor shaped (C, H, W).

    ``len(dataset) == len(file_names) * args.data_copy``: indices wrap around
    the image list, so each image is visited ``data_copy`` times per epoch,
    with a fresh random crop each time.
    """

    def __init__(self, args, train=True):
        """Index the images and build any missing cache entries.

        Args:
            args: namespace read for ``data_dir``, ``train_name``, ``ext`` and
                ``data_copy`` here, and ``rgb_range``, ``n_channels``,
                ``patch_size`` per item.
            train: stored as ``self.train``; not otherwise used in this class.
        """
        super(SlowDataset, self).__init__()
        self.args = args
        self.train = train
        self.bin_file_names = list()
        self.image_folder = os.path.join('.', args.data_dir, args.train_name)
        self.bin_image_folder = os.path.join('.', args.data_dir, args.train_name + 'bin')
        self.ext = '/*%s' % args.ext
        source_images = sorted(glob.glob(self.image_folder + self.ext))
        if source_images:
            # Source images are available: build every cache entry that is
            # missing. Checking per file (rather than trusting that the cache
            # folder exists) recovers from a previously interrupted preparation.
            os.makedirs(self.bin_image_folder, exist_ok=True)
            self.file_names = source_images
            self.prepare_cache()
        else:
            # Only the cache is available: train directly from the .npy files.
            self.file_names = sorted(glob.glob(os.path.join(self.bin_image_folder, '*.npy')))
            self.bin_file_names = self.file_names
        self.data_copy = args.data_copy

    def _bin_name(self, fname):
        """Map a source image path to its ``.npy`` path in ``bin_image_folder``."""
        base = os.path.basename(fname)
        ext = self.args.ext
        if ext and base.endswith(ext):
            base = base[:-len(ext)]
        else:
            # e.g. case-insensitive glob matched "IMG.PNG" for ext ".png"
            base = os.path.splitext(base)[0]
        return os.path.join(self.bin_image_folder, base + '.npy')

    def prepare_cache(self):
        """Decode each image in ``self.file_names`` once and save it as ``.npy``.

        Appends one cache path per source image to ``self.bin_file_names`` (same
        order as ``self.file_names``). Entries whose ``.npy`` already exists are
        not decoded again.
        """
        for fname in self.file_names:
            bin_fname = self._bin_name(fname)
            self.bin_file_names.append(bin_fname)
            if not os.path.exists(bin_fname):
                img = imageio.imread(fname)
                # Write under a temporary name and rename into place, so a run
                # killed mid-write never leaves a truncated file that a later
                # run would mistake for a valid cache entry.
                tmp_fname = bin_fname + '.tmp'
                with open(tmp_fname, 'wb') as f:
                    np.save(f, img)
                os.replace(tmp_fname, bin_fname)
                print('%s prepared!' % (bin_fname))

    def __len__(self):
        """Number of images times ``data_copy``."""
        return len(self.file_names) * self.data_copy

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps the
    # dataset instance alive for the cache's lifetime (ruff B019); acceptable
    # for long-lived datasets. Each DataLoader worker process has its own cache.
    @lru_cache(maxsize=400)
    def get_ndarray(self, fname):
        """Load a cached ``.npy`` array, memoizing up to 400 recent files."""
        return np.load(fname)

    def __getitem__(self, index):
        """Return a random augmented patch of image ``index % len(file_names)``.

        Returns:
            float tensor shaped (n_channels, patch_size, patch_size), multiplied
            by ``rgb_range / 255`` (assumes 8-bit source images — TODO confirm
            if 16-bit images are ever used).

        Raises:
            ValueError: if the image is smaller than ``patch_size``.
        """
        rgb_range = self.args.rgb_range
        n_channels = self.args.n_channels
        path = self.bin_file_names[index % len(self.file_names)]
        # torch.Tensor(...) copies into a new float tensor, so the cached
        # ndarray is never modified by the steps below.
        img = torch.Tensor(self.get_ndarray(path))
        if img.dim() == 2:
            # Grayscale arrays are (H, W); add an explicit channel axis.
            img = img.unsqueeze(2)
        c = img.shape[2]
        if c == 4 and n_channels in (1, 3):
            # Drop the alpha channel of RGBA images (e.g. PNGs with
            # transparency) so they are converted like plain RGB.
            img = img[:, :, :3]
            c = 3
        # RGB input with single-channel output: keep only the Y (luma) channel.
        if n_channels == 1 and c == 3:
            img = rgb2ycbcr(img)[:, :, 0].unsqueeze(2)
        elif n_channels == 3 and c == 1:
            img = img.repeat(1, 1, 3)
        height, width, _ = img.shape
        patch = self.args.patch_size
        if height < patch or width < patch:
            raise ValueError('image %s is %dx%d, smaller than patch_size=%d'
                             % (path, height, width, patch))
        top = random.randint(0, height - patch)
        left = random.randint(0, width - patch)
        img = img[top:top + patch, left:left + patch, :]
        img_tensor = img.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
        img_tensor = img_tensor * rgb_range / 255.0
        img_tensor = self.augment(img_tensor)
        return img_tensor

    def augment(self, img, hflip=True, rot=True):
        """Randomly flip and transpose a (C, H, W) patch.

        Each transform is drawn independently with probability 0.5:
        a horizontal flip (width axis) if ``hflip``, plus a vertical flip
        (height axis) and an H/W transpose if ``rot``. With both enabled,
        the combinations cover all 8 flips/rotations of a square patch.
        The channel axis is never modified.
        """
        hflip = hflip and random.random() < 0.5
        vflip = rot and random.random() < 0.5
        rot90 = rot and random.random() < 0.5
        if hflip:
            img = img.flip([2])  # width axis of (C, H, W)
        if vflip:
            img = img.flip([1])  # height axis of (C, H, W)
        if rot90:
            img = img.permute(0, 2, 1)
        return img