forked from AliaksandrSiarohin/monkey-net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path frames_dataset.py
131 lines (101 loc) · 4.55 KB
/
frames_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
from skimage import io, img_as_float32
from skimage.color import gray2rgb
from sklearn.model_selection import train_test_split
from imageio import mimread
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from augmentation import AllAugmentationTransform, VideoToTensor
def read_video(name, image_shape):
    """Load a video file as an array of RGB frames.

    Supported inputs (extension matched case-insensitively):
      * ``.png`` / ``.jpg`` / ``.jpeg``: a single image holding the frames placed side by
        side, left to right; it is cut into frames of ``image_shape``.
      * ``.gif`` / ``.mp4``: decoded frame by frame with ``imageio.mimread``.

    Grayscale inputs are expanded to RGB, and a trailing alpha channel is dropped.

    Args:
        name: path of the file to read.
        image_shape: ``(height, width, ...)`` of one frame; only used for the image
            formats, to split the horizontal strip into frames.

    Returns:
        numpy array of shape ``(num_frames, height, width, channels)`` converted with
        ``img_as_float32``.

    Raises:
        ValueError: if the extension is unsupported, or if an image strip cannot be split
            into frames of ``image_shape``.
    """
    lower_name = name.lower()
    if lower_name.endswith(('.png', '.jpg', '.jpeg')):
        image = io.imread(name)
        # gray2rgb needs a 2-D array: a trailing singleton channel would otherwise give a
        # 4-D (H, W, 1, 3) result. Two channels are treated as gray + alpha (e.g. 'LA'
        # PNGs); alpha is dropped, as for RGBA below.
        if image.ndim == 3 and image.shape[2] in (1, 2):
            image = image[..., 0]
        if image.ndim == 2:
            image = gray2rgb(image)
        if image.shape[2] == 4:
            image = image[..., :3]
        image = img_as_float32(image)

        height, width = image_shape[0], image_shape[1]
        if image.shape[0] != height or image.shape[1] % width != 0:
            raise ValueError("Image %s of shape %s cannot be split into frames of shape %s"
                             % (name, image.shape, tuple(image_shape)))
        # The strip is (height, num_frames * width, channels). Split the width axis into
        # (num_frames, width), then bring num_frames to the front. This is correct for
        # non-square frames too, unlike a transpose-then-reshape.
        video_array = image.reshape(height, -1, width, image.shape[2])
        video_array = np.moveaxis(video_array, 1, 0)
    elif lower_name.endswith(('.gif', '.mp4')):
        video = np.array(mimread(name))
        # Same single-channel / gray+alpha normalisation as the image branch.
        if video.ndim == 4 and video.shape[-1] in (1, 2):
            video = video[..., 0]
        if video.ndim == 3:
            video = np.array([gray2rgb(frame) for frame in video])
        if video.shape[-1] == 4:
            video = video[..., :3]
        video_array = img_as_float32(video)
    else:
        raise ValueError("Unknown file extensions %s" % name)
    return video_array
class FramesDataset(Dataset):
    """Dataset of videos, videos can be represented as an image of concatenated frames, or in '.mp4','.gif' format.

    Each item is the dict returned by ``transform`` for the decoded video, with the file's
    base name added under the ``'name'`` key.

    Args:
        root_dir: directory of video files, or a directory with ``train`` and ``test``
            sub-directories that define the split explicitly.
        augmentation_params: keyword arguments for ``AllAugmentationTransform``; only used
            when ``is_train`` is true and ``transform`` is None.
        image_shape: shape of one frame, forwarded to ``read_video``.
        is_train: select the training split (True) or the test split (False).
        random_seed: seed of the random train/test split (ignored for a predefined split).
        pairs_list: stored unchanged on the instance; ``PairedDataset`` reads it.
        transform: callable applied to each decoded video. When None, a default is chosen
            from ``is_train``.
        test_size: fraction of files held out for testing when no predefined split exists.

    Raises:
        FileNotFoundError: if ``root_dir/train`` exists but ``root_dir/test`` does not.
    """

    def __init__(self, root_dir, augmentation_params, image_shape=(64, 64, 3), is_train=True,
                 random_seed=0, pairs_list=None, transform=None, test_size=0.2):
        self.root_dir = root_dir
        # os.listdir order is filesystem-dependent. Sorting makes the seeded random split
        # (and the index-based pairs built on top of it) reproducible across machines.
        self.images = sorted(os.listdir(root_dir))
        self.image_shape = tuple(image_shape)
        self.pairs_list = pairs_list

        train_dir = os.path.join(root_dir, 'train')
        test_dir = os.path.join(root_dir, 'test')
        if os.path.exists(train_dir):
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not os.path.exists(test_dir):
                raise FileNotFoundError("Predefined split found but '%s' is missing" % test_dir)
            print("Use predefined train-test split.")
            train_images = sorted(os.listdir(train_dir))
            test_images = sorted(os.listdir(test_dir))
            self.root_dir = train_dir if is_train else test_dir
        else:
            print("Use random train-test split.")
            train_images, test_images = train_test_split(self.images, random_state=random_seed,
                                                         test_size=test_size)
        self.images = train_images if is_train else test_images

        if transform is not None:
            self.transform = transform
        elif is_train:
            self.transform = AllAugmentationTransform(**augmentation_params)
        else:
            self.transform = VideoToTensor()

    def __len__(self):
        """Return the number of videos in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Read the ``idx``-th video, transform it, and tag the result with its file name."""
        img_name = os.path.join(self.root_dir, self.images[idx])
        video_array = read_video(img_name, image_shape=self.image_shape)
        out = self.transform(video_array)
        # add names
        out['name'] = os.path.basename(img_name)
        return out
class PairedDataset(Dataset):
    """
    Dataset of pairs for transfer.

    Wraps another dataset. Item ``i`` merges the two wrapped items of pair
    ``(driving_idx, source_idx)`` into one dict, prefixing their keys with ``'driving_'``
    and ``'source_'`` respectively.

    Pairs come from one of two places:
      * ``initial_dataset.pairs_list is None``: up to ``number_of_pairs`` distinct index
        pairs are drawn at random (seeded by ``seed``) from every combination over the
        first ``min(number_of_pairs, len(initial_dataset))`` items. A combination may use
        the same item as both driving and source.
      * otherwise ``pairs_list`` is read with ``pd.read_csv``; it needs ``source`` and
        ``driving`` columns of names from ``initial_dataset.images``. Rows naming unknown
        files are dropped and the first ``number_of_pairs`` remaining rows are used.

    Args:
        initial_dataset: dataset exposing ``pairs_list``, ``images``, ``__len__`` and a
            ``__getitem__`` returning a dict.
        number_of_pairs: upper bound on the number of pairs.
        seed: seed of the random pair selection. A private RNG is used, so the global NumPy
            random state is not modified.
    """

    def __init__(self, initial_dataset, number_of_pairs, seed=0):
        self.initial_dataset = initial_dataset
        pairs_list = self.initial_dataset.pairs_list

        if pairs_list is None:
            max_idx = min(number_of_pairs, len(initial_dataset))
            # Every (driving, source) index combination over the first max_idx items.
            xy = np.mgrid[:max_idx, :max_idx].reshape(2, -1).T
            number_of_pairs = min(xy.shape[0], number_of_pairs)
            # A local RandomState gives the same draw as `np.random.seed(seed)` followed by
            # `np.random.choice(...)`, without reseeding the process-wide generator.
            rng = np.random.RandomState(seed)
            self.pairs = xy.take(rng.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)
        else:
            images = self.initial_dataset.images
            name_to_index = {name: index for index, name in enumerate(images)}
            pairs = pd.read_csv(pairs_list)
            known = np.logical_and(pairs['source'].isin(images), pairs['driving'].isin(images))
            pairs = pairs[known]
            number_of_pairs = min(pairs.shape[0], number_of_pairs)
            selected = pairs.iloc[:number_of_pairs]
            self.pairs = [(name_to_index[driving], name_to_index[source])
                          for driving, source in zip(selected['driving'], selected['source'])]
            self.start_frames = []

    def __len__(self):
        """Return the number of pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the driving and source items of pair ``idx`` merged into one dict."""
        driving_idx, source_idx = self.pairs[idx]
        driving = self.initial_dataset[driving_idx]
        source = self.initial_dataset[source_idx]
        out = {'driving_' + key: value for key, value in driving.items()}
        out.update({'source_' + key: value for key, value in source.items()})
        return out